summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-01-15 18:50:56 -0500
committerGitHub <noreply@github.com>2025-01-15 15:50:56 -0800
commit9b977e59cf786bbb000b3b868b126c2b9a17d3f3 (patch)
tree95ea341da78da12ac831f48a4d0827aa3f444d5c
parentd9d0b4f03277027690a909a76b78d1622ac13498 (diff)
Reuse code for Metal and WGSL entry point legalization (#6063)
* Refactor to reuse common for metal and wgsl entry point legalization * refactor system val work item * refactor simplify user names * clean up fix semantic field of struct * improve code layout * split wgsl/metal to seperate classes and cleanup * remove extra includes * remove dead code comments * minor cleanup * squash merge from master and resolve conflict * apply metal spec const thread count changes * Revert "apply metal spec const thread count changes" This reverts commit c42d707fd25ee0328598650d3235cd2322810ccc. * Revert "squash merge from master and resolve conflict" This reverts commit 06db88ef7001bdfe93fb23af35af0d026b255dee. * Merge remote-tracking branch 'origin/master' * apply metal spec const thread count changes * Revert "apply metal spec const thread count changes" This reverts commit 3b9e6f53cee2e6076ac2b7a0d015a1ed2cbbd627. * Revert "Merge remote-tracking branch 'origin/master'" This reverts commit 99869d573a46dadeb24445405f5a1e37a8e03d0d. * apply metal spec const thread count changes --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp2389
-rw-r--r--source/slang/slang-ir-legalize-varying-params.h22
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp1959
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp1637
4 files changed, 2538 insertions, 3469 deletions
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index a98290545..69d62c8bf 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -6,6 +6,8 @@
#include "slang-ir-util.h"
#include "slang-parameter-binding.h"
+#include <set>
+
namespace Slang
{
// Convert semantic name (ignores case) into equivlent `SystemValueSemanticName`
@@ -1560,4 +1562,2391 @@ void depointerizeInputParams(IRFunc* entryPointFunc)
}
}
+
+class LegalizeShaderEntryPointContext
+{
+public:
+ void legalizeEntryPoints(List<EntryPointInfo>& entryPoints)
+ {
+ for (auto entryPoint : entryPoints)
+ legalizeEntryPoint(entryPoint);
+ removeSemanticLayoutsFromLegalizedStructs();
+ }
+
+protected:
+ LegalizeShaderEntryPointContext(IRModule* module, DiagnosticSink* sink, bool hoistParameters)
+ : m_module(module), m_sink(sink), hoistParameters(hoistParameters)
+ {
+ }
+
+ IRModule* m_module;
+ DiagnosticSink* m_sink;
+
+ struct SystemValueInfo
+ {
+ String systemValueName;
+ SystemValueSemanticName systemValueNameEnum;
+ ShortList<IRType*> permittedTypes;
+
+ bool isUnsupported = false;
+ bool isSpecial = false;
+ };
+
+ struct SystemValLegalizationWorkItem
+ {
+ IRInst* var;
+ IRType* varType;
+
+ String attrName;
+ UInt attrIndex;
+ };
+
+ virtual SystemValueInfo getSystemValueInfo(
+ String inSemanticName,
+ String* optionalSemanticIndex,
+ IRInst* parentVar) const = 0;
+
+ virtual List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(
+ EntryPointInfo entryPoint) const = 0;
+
+ virtual void flattenNestedStructsTransferKeyDecorations(IRInst* newKey, IRInst* oldKey)
+ const = 0;
+
+ virtual UnownedStringSlice getUserSemanticNameSlice(String& loweredName, bool isUserSemantic)
+ const = 0;
+
+ virtual void addFragmentShaderReturnValueDecoration(
+ IRBuilder& builder,
+ IRInst* returnValueStructKey) const = 0;
+
+
+ virtual IRVarLayout* handleGeometryStageParameterVarLayout(
+ IRBuilder& builder,
+ IRVarLayout* paramVarLayout) const
+ {
+ SLANG_UNUSED(builder);
+ return paramVarLayout;
+ }
+
+ virtual void handleSpecialSystemValue(
+ const EntryPointInfo& entryPoint,
+ SystemValLegalizationWorkItem& workItem,
+ const SystemValueInfo& info,
+ IRBuilder& builder)
+ {
+ SLANG_UNUSED(entryPoint);
+ SLANG_UNUSED(workItem);
+ SLANG_UNUSED(info);
+ SLANG_UNUSED(builder);
+ }
+
+ virtual void legalizeAmplificationStageEntryPoint(const EntryPointInfo& entryPoint) const
+ {
+ SLANG_UNUSED(entryPoint);
+ }
+
+ virtual void legalizeMeshStageEntryPoint(const EntryPointInfo& entryPoint) const
+ {
+ SLANG_UNUSED(entryPoint);
+ }
+
+
+ std::optional<SystemValLegalizationWorkItem> tryToMakeSystemValWorkItem(
+ IRInst* var,
+ IRType* varType) const
+ {
+ if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>())
+ {
+ if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ {
+ return {
+ {var,
+ varType,
+ String(semanticDecoration->getSemanticName()).toLower(),
+ (UInt)semanticDecoration->getSemanticIndex()}};
+ }
+ }
+
+ auto layoutDecor = var->findDecoration<IRLayoutDecoration>();
+ if (!layoutDecor)
+ return {};
+ auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>();
+ if (!sysValAttr)
+ return {};
+ auto semanticName = String(sysValAttr->getName());
+ auto sysAttrIndex = sysValAttr->getIndex();
+
+ return {{var, varType, semanticName, sysAttrIndex}};
+ }
+
+ void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem)
+ {
+ IRBuilder builder(entryPoint.entryPointFunc);
+
+ auto var = workItem.var;
+ auto varType = workItem.varType;
+ auto semanticName = workItem.attrName;
+
+ auto indexAsString = String(workItem.attrIndex);
+ SystemValueInfo info = getSystemValueInfo(semanticName, &indexAsString, var);
+ if (info.isSpecial)
+ {
+ handleSpecialSystemValue(entryPoint, workItem, info, builder);
+ }
+
+ if (info.isUnsupported)
+ {
+ reportUnsupportedSystemAttribute(var, semanticName);
+ return;
+ }
+ if (!info.permittedTypes.getCount())
+ return;
+
+ builder.addTargetSystemValueDecoration(var, info.systemValueName.getUnownedSlice());
+
+ bool varTypeIsPermitted = false;
+ for (auto& permittedType : info.permittedTypes)
+ {
+ varTypeIsPermitted = varTypeIsPermitted || permittedType == varType;
+ }
+
+ if (!varTypeIsPermitted)
+ {
+ // Note: we do not currently prefer any conversion
+ // example:
+ // * allowed types for semantic: `float4`, `uint4`, `int4`
+ // * user used, `float2`
+ // * Slang will equally prefer `float4` to `uint4` to `int4`.
+ // This means the type may lose data if slang selects `uint4` or `int4`.
+ bool foundAConversion = false;
+ for (auto permittedType : info.permittedTypes)
+ {
+ var->setFullType(permittedType);
+ builder.setInsertBefore(
+ entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+
+ // get uses before we `tryConvertValue` since this creates a new use
+ List<IRUse*> uses;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ uses.add(use);
+
+ auto convertedValue = tryConvertValue(builder, var, varType);
+ if (convertedValue == nullptr)
+ continue;
+
+ foundAConversion = true;
+ copyNameHintAndDebugDecorations(convertedValue, var);
+
+ for (auto use : uses)
+ builder.replaceOperand(use, convertedValue);
+ }
+ if (!foundAConversion)
+ {
+ // If we can't convert the value, report an error.
+ for (auto permittedType : info.permittedTypes)
+ {
+ StringBuilder typeNameSB;
+ getTypeNameHint(typeNameSB, permittedType);
+ m_sink->diagnose(
+ var->sourceLoc,
+ Diagnostics::systemValueTypeIncompatible,
+ semanticName,
+ typeNameSB.produceString());
+ }
+ }
+ }
+ }
+
+private:
+ const bool hoistParameters;
+ HashSet<IRStructField*> semanticInfoToRemove;
+
+ void removeSemanticLayoutsFromLegalizedStructs()
+ {
+ // Metal and WGSL does not allow duplicate attributes to appear in the same shader.
+ // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]`
+ // must be removed.
+ for (auto field : semanticInfoToRemove)
+ {
+ auto key = field->getKey();
+ // Some decorations appear twice, destroy all found
+ for (;;)
+ {
+ if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ {
+ semanticDecor->removeAndDeallocate();
+ continue;
+ }
+ else if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
+ {
+ layoutDecor->removeAndDeallocate();
+ continue;
+ }
+ break;
+ }
+ }
+ }
+
+ void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint)
+ {
+ // If an entry point has a input parameter with a struct type, we want to hoist out
+ // all the fields of the struct type to be individual parameters of the entry point.
+ // This will canonicalize the entry point signature, so we can handle all cases uniformly.
+
+ // For example, given an entry point:
+ // ```
+ // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID};
+ // void main(VertexInput vin) { ... }
+ // ```
+ // We will transform it to:
+ // ```
+ // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
+ // VertexInput vin = {pos,uv,vertexId};
+ // ...
+ // }
+ // ```
+
+ auto func = entryPoint.entryPointFunc;
+ List<IRParam*> paramsToProcess;
+ for (auto param : func->getParams())
+ {
+ if (as<IRStructType>(param->getDataType()))
+ {
+ paramsToProcess.add(param);
+ }
+ }
+
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ for (auto param : paramsToProcess)
+ {
+ auto structType = as<IRStructType>(param->getDataType());
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto varLayout = findVarLayout(param);
+
+ // If `param` already has a semantic, we don't want to hoist its fields out.
+ if (varLayout->findSystemValueSemanticAttr() != nullptr ||
+ param->findDecoration<IRSemanticDecoration>())
+ continue;
+
+ IRStructTypeLayout* structTypeLayout = nullptr;
+ if (varLayout)
+ structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
+ Index fieldIndex = 0;
+ List<IRInst*> fieldParams;
+ for (auto field : structType->getFields())
+ {
+ auto fieldParam = builder.emitParam(field->getFieldType());
+ IRCloneEnv cloneEnv;
+ cloneInstDecorationsAndChildren(
+ &cloneEnv,
+ builder.getModule(),
+ field->getKey(),
+ fieldParam);
+
+ IRVarLayout* fieldLayout =
+ structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr;
+ if (varLayout)
+ {
+ IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout());
+ varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout);
+ for (auto offsetAttr : fieldLayout->getOffsetAttrs())
+ {
+ auto parentOffsetAttr =
+ varLayout->findOffsetAttr(offsetAttr->getResourceKind());
+ UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0;
+ UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0;
+ auto resInfo =
+ varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
+ resInfo->offset = parentOffset + offsetAttr->getOffset();
+ resInfo->space = parentSpace + offsetAttr->getSpace();
+ }
+ builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build());
+ }
+ param->insertBefore(fieldParam);
+ fieldParams.add(fieldParam);
+ fieldIndex++;
+ }
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto reconstructedParam =
+ builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer());
+ param->replaceUsesWith(reconstructedParam);
+ param->removeFromParent();
+ }
+ fixUpFuncType(func);
+ }
+
+ // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct
+ void flattenInputParameters(EntryPointInfo entryPoint)
+ {
+ // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members).
+ /*
+ // Assume the following code
+ struct NestedFragment
+ {
+ float2 p3;
+ };
+ struct Fragment
+ {
+ float4 p1;
+ float3 p2;
+ NestedFragment p3_nested;
+ };
+
+ // Fragment flattens into
+ struct Fragment
+ {
+ float4 p1;
+ float3 p2;
+ float2 p3;
+ };
+ */
+
+ // This is important since Metal and WGSL does not allow semantic's on a struct
+ /*
+ // Assume the following code
+ struct NestedFragment1
+ {
+ float2 p3;
+ };
+ struct Fragment1
+ {
+ float4 p1 : SV_TARGET0;
+ float3 p2 : SV_TARGET1;
+ NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct
+ };
+
+ */
+
+ // Metal does allow semantics on members of a nested struct but we are avoiding this
+ // approach since there are senarios where legalization (and verification) is
+ // hard/expensive without creating a flat struct:
+ // 1. Entry points may share structs, semantics may be inconsistent across entry points
+ // 2. Multiple of the same struct may be used in a param list
+ //
+ // WGSL does NOT allow semantics on members of a nested struct.
+ /*
+ // Assume the following code
+ struct NestedFragment
+ {
+ float2 p3;
+ };
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ NestedFragment p2 : SV_TARGET1;
+ NestedFragment p3 : SV_TARGET2;
+ };
+
+ // Legalized without flattening -- abandoned
+ struct NestedFragment1
+ {
+ float2 p3 : SV_TARGET1;
+ };
+ struct NestedFragment2
+ {
+ float2 p3 : SV_TARGET2;
+ };
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ NestedFragment1 p2;
+ NestedFragment2 p3;
+ };
+
+ // Legalized with flattening -- current approach
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ float2 p2 : SV_TARGET1;
+ float2 p3 : SV_TARGET2;
+ };
+ */
+
+ auto func = entryPoint.entryPointFunc;
+ bool modified = false;
+ for (auto param : func->getParams())
+ {
+ auto layout = findVarLayout(param);
+ if (!layout)
+ continue;
+ if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ continue;
+ // If we find a IRParam with a IRStructType member, we need to flatten the entire
+ // IRParam
+ if (auto structType = as<IRStructType>(param->getDataType()))
+ {
+ IRBuilder builder(func);
+ MapStructToFlatStruct mapOldFieldToNewField;
+
+ // Flatten struct if we have nested IRStructType
+ auto flattenedStruct = maybeFlattenNestedStructs(
+ builder,
+ structType,
+ mapOldFieldToNewField,
+ semanticInfoToRemove);
+
+ // Validate/rearange all semantics which overlap in our flat struct.
+ fixFieldSemanticsOfFlatStruct(flattenedStruct);
+ ensureStructHasUserSemantic<LayoutResourceKind::VaryingInput>(
+ flattenedStruct,
+ layout);
+ if (flattenedStruct != structType)
+ {
+ // Replace the 'old IRParam type' with a 'new IRParam type'
+ param->setFullType(flattenedStruct);
+
+ // Emit a new variable at EntryPoint of 'old IRParam type'
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto dstVal = builder.emitVar(structType);
+ auto dstLoad = builder.emitLoad(dstVal);
+ param->replaceUsesWith(dstLoad);
+ builder.setInsertBefore(dstLoad);
+ // Copy the 'new IRParam type' to our 'old IRParam type'
+ mapOldFieldToNewField
+ .emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>(
+ builder,
+ dstVal,
+ param);
+
+ modified = true;
+ }
+ }
+ }
+ if (modified)
+ fixUpFuncType(func);
+ }
+
+ void packStageInParameters(EntryPointInfo entryPoint)
+ {
+ // If the entry point has any parameters whose layout contains VaryingInput,
+ // we need to pack those parameters into a single `struct` type, and decorate
+ // the fields with the appropriate `[[attribute]]` decorations.
+ // For other parameters that are not `VaryingInput`, we need to leave them as is.
+ //
+ // For example, given this code after `hoistEntryPointParameterFromStruct`:
+ // ```
+ // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
+ // VertexInput vin = {pos,uv,vertexId};
+ // ...
+ // }
+ // ```
+ // We are going to transform it into:
+ // ```
+ // struct VertexInput {
+ // float3 pos [[attribute(0)]];
+ // float2 uv [[attribute(1)]];
+ // };
+ // void main(VertexInput vin, int vertexId : SV_VertexID) {
+ // let pos = vin.pos;
+ // let uv = vin.uv;
+ // ...
+ // }
+
+ auto func = entryPoint.entryPointFunc;
+
+ bool isGeometryStage = false;
+ switch (entryPoint.entryPointDecor->getProfile().getStage())
+ {
+ case Stage::Vertex:
+ case Stage::Amplification:
+ case Stage::Mesh:
+ case Stage::Geometry:
+ case Stage::Domain:
+ case Stage::Hull:
+ isGeometryStage = true;
+ break;
+ }
+
+ List<IRParam*> paramsToPack;
+ for (auto param : func->getParams())
+ {
+ auto layout = findVarLayout(param);
+ if (!layout)
+ continue;
+ if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ continue;
+ paramsToPack.add(param);
+ }
+
+ if (paramsToPack.getCount() == 0)
+ return;
+
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ IRStructType* structType = builder.createStructType();
+ auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
+ builder.addNameHintDecoration(
+ structType,
+ (String(stageText) + toSlice("Input")).getUnownedSlice());
+ List<IRStructKey*> keys;
+ IRStructTypeLayout::Builder layoutBuilder(&builder);
+ for (auto param : paramsToPack)
+ {
+ auto paramVarLayout = findVarLayout(param);
+ auto key = builder.createStructKey();
+ param->transferDecorationsTo(key);
+ builder.createStructField(structType, key, param->getDataType());
+ if (auto varyingInOffsetAttr =
+ paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ {
+ if (!key->findDecoration<IRSemanticDecoration>() &&
+ !paramVarLayout->findAttr<IRSemanticAttr>())
+ {
+ // If the parameter doesn't have a semantic, we need to add one for semantic
+ // matching.
+ builder.addSemanticDecoration(
+ key,
+ toSlice("_slang_attr"),
+ (int)varyingInOffsetAttr->getOffset());
+ }
+ }
+
+ if (isGeometryStage)
+ {
+ paramVarLayout = handleGeometryStageParameterVarLayout(builder, paramVarLayout);
+ }
+
+ layoutBuilder.addField(key, paramVarLayout);
+ builder.addLayoutDecoration(key, paramVarLayout);
+ keys.add(key);
+ }
+ builder.setInsertInto(func->getFirstBlock());
+ auto packedParam = builder.emitParamAtHead(structType);
+ auto typeLayout = layoutBuilder.build();
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+
+ // Add a VaryingInput resource info to the packed parameter layout, so that we can emit
+ // the needed `[[stage_in]]` attribute in Metal emitter.
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(packedParam, paramVarLayout);
+
+ // Replace the original parameters with the packed parameter
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++)
+ {
+ auto param = paramsToPack[paramIndex];
+ auto key = keys[paramIndex];
+ auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key);
+ param->replaceUsesWith(paramField);
+ param->removeFromParent();
+ }
+ fixUpFuncType(func);
+ }
+
+
+ void reportUnsupportedSystemAttribute(IRInst* param, String semanticName)
+ {
+ m_sink->diagnose(
+ param->sourceLoc,
+ Diagnostics::systemValueAttributeNotSupported,
+ semanticName);
+ }
+
+ template<LayoutResourceKind K>
+ void ensureStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout)
+ {
+ // Ensure each field in an output struct type has either a system semantic or a user
+ // semantic, so that signature matching can happen correctly.
+ auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
+ Index index = 0;
+ IRBuilder builder(structType);
+ for (auto field : structType->getFields())
+ {
+ auto key = field->getKey();
+ if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ {
+ if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ {
+ auto indexAsString = String(UInt(semanticDecor->getSemanticIndex()));
+ auto sysValInfo =
+ getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field);
+ if (sysValInfo.isUnsupported)
+ {
+ reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName());
+ }
+ else
+ {
+ builder.addTargetSystemValueDecoration(
+ key,
+ sysValInfo.systemValueName.getUnownedSlice());
+ semanticDecor->removeAndDeallocate();
+ }
+ }
+ index++;
+ continue;
+ }
+ typeLayout->getFieldLayout(index);
+ auto fieldLayout = typeLayout->getFieldLayout(index);
+ if (auto offsetAttr = fieldLayout->findOffsetAttr(K))
+ {
+ UInt varOffset = 0;
+ if (auto varOffsetAttr = varLayout->findOffsetAttr(K))
+ varOffset = varOffsetAttr->getOffset();
+ varOffset += offsetAttr->getOffset();
+ builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset);
+ }
+ index++;
+ }
+ }
+
+ // Stores a hicharchy of members and children which map 'oldStruct->member' to
+ // 'flatStruct->member' Note: this map assumes we map to FlatStruct since it is easier/faster to
+ // process
+ struct MapStructToFlatStruct
+ {
+ /*
+ We need a hicharchy map to resolve dependencies for mapping
+ oldStruct to newStruct efficently. Example:
+
+ MyStruct
+ |
+ / | \
+ / | \
+ / | \
+ M0<A> M1<A> M2<B>
+ | | |
+ A_0 A_0 B_0
+
+ Without storing hicharchy information, there will be no way to tell apart
+ `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField
+ only has 1 instance of `A::A0`
+ */
+
+ enum CopyOptions : int
+ {
+ // Copy a flattened-struct into a struct
+ FlatStructIntoStruct = 0,
+
+ // Copy a struct into a flattened-struct
+ StructIntoFlatStruct = 1,
+ };
+
+ private:
+ // Children of member if applicable.
+ Dictionary<IRStructField*, MapStructToFlatStruct> members;
+
+ // Field correlating to MapStructToFlatStruct Node.
+ IRInst* node;
+ IRStructKey* getKey()
+ {
+ SLANG_ASSERT(as<IRStructField>(node));
+ return as<IRStructField>(node)->getKey();
+ }
+ IRInst* getNode() { return node; }
+ IRType* getFieldType()
+ {
+ SLANG_ASSERT(as<IRStructField>(node));
+ return as<IRStructField>(node)->getFieldType();
+ }
+
+ // Whom node maps to inside target flatStruct
+ IRStructField* targetMapping;
+
+ auto begin() { return members.begin(); }
+ auto end() { return members.end(); }
+
+ // Copies members of oldStruct to/from newFlatStruct. Assumes members of val1 maps to
+ // members in val2 using `MapStructToFlatStruct`
+ template<int copyOptions>
+ static void _emitCopy(
+ IRBuilder& builder,
+ IRInst* val1,
+ IRStructType* type1,
+ IRInst* val2,
+ IRStructType* type2,
+ MapStructToFlatStruct& node)
+ {
+ for (auto& field1Pair : node)
+ {
+ auto& field1 = field1Pair.second;
+
+ // Get member of val1
+ IRInst* fieldAddr1 = nullptr;
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey());
+ }
+ else
+ {
+ if (as<IRPtrTypeBase>(val1))
+ val1 = builder.emitLoad(val1);
+ fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey());
+ }
+
+ // If val1 is a struct, recurse
+ if (auto fieldAsStruct1 = as<IRStructType>(field1.getFieldType()))
+ {
+ _emitCopy<copyOptions>(
+ builder,
+ fieldAddr1,
+ fieldAsStruct1,
+ val2,
+ type2,
+ field1);
+ continue;
+ }
+
+ // Get member of val2 which maps to val1.member
+ auto field2 = field1.getMapping();
+ SLANG_ASSERT(field2);
+ IRInst* fieldAddr2 = nullptr;
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ if (as<IRPtrTypeBase>(val2))
+ val2 = builder.emitLoad(val1);
+ fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey());
+ }
+ else
+ {
+ fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey());
+ }
+
+ // Copy val2/val1 member into val1/val2 member
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ builder.emitStore(fieldAddr1, fieldAddr2);
+ }
+ else
+ {
+ builder.emitStore(fieldAddr2, fieldAddr1);
+ }
+ }
+ }
+
+ public:
+ void setNode(IRInst* newNode) { node = newNode; }
+ // Get 'MapStructToFlatStruct' that is a child of 'parent'.
+ // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'.
+ MapStructToFlatStruct& getMember(IRStructField* member) { return members[member]; }
+ MapStructToFlatStruct& operator[](IRStructField* member) { return getMember(member); }
+
+ void setMapping(IRStructField* newTargetMapping) { targetMapping = newTargetMapping; }
+ // Get 'MapStructToFlatStruct' that is a child of 'parent'.
+ // Return nullptr if no member is mapped to 'parent'
+ IRStructField* getMapping() { return targetMapping; }
+
+ // Copies srcVal into dstVal using hicharchy map.
+ template<int copyOptions>
+ void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal)
+ {
+ auto dstType = dstVal->getDataType();
+ if (auto dstPtrType = as<IRPtrTypeBase>(dstType))
+ dstType = dstPtrType->getValueType();
+ auto dstStructType = as<IRStructType>(dstType);
+ SLANG_ASSERT(dstStructType);
+
+ auto srcType = srcVal->getDataType();
+ if (auto srcPtrType = as<IRPtrTypeBase>(srcType))
+ srcType = srcPtrType->getValueType();
+ auto srcStructType = as<IRStructType>(srcType);
+ SLANG_ASSERT(srcStructType);
+
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a
+ // struct
+ SLANG_ASSERT(node == dstStructType);
+ _emitCopy<copyOptions>(
+ builder,
+ dstVal,
+ dstStructType,
+ srcVal,
+ srcStructType,
+ *this);
+ }
+ else
+ {
+ // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct
+ SLANG_ASSERT(node == srcStructType);
+ _emitCopy<copyOptions>(
+ builder,
+ srcVal,
+ srcStructType,
+ dstVal,
+ dstStructType,
+ *this);
+ }
+ }
+ };
+
+ IRStructType* _flattenNestedStructs(
+ IRBuilder& builder,
+ IRStructType* dst,
+ IRStructType* src,
+ IRSemanticDecoration* parentSemanticDecoration,
+ IRLayoutDecoration* parentLayout,
+ MapStructToFlatStruct& mapFieldToField,
+ HashSet<IRStructField*>& varsWithSemanticInfo)
+ {
+ // For all fields ('oldField') of a struct do the following:
+ // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration,
+ // IRLayoutDecoration), store these if found.
+ // * Do not propagate semantic info if the current node has *any* form of semantic
+ // information.
+ // Update varsWithSemanticInfo.
+ // 2. If IRStructType:
+ // 2a. Recurse this function with 'decorations that carry semantic info' from parent.
+ // 3. If not IRStructType:
+ // 3a Metal. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic
+ // info'.
+ //
+ // 3a WGSL. Emit 'newField' with 'newKey' equal to 'oldField' and 'oldKey', respectively,
+ // where 'oldKey' is the key corresponding to 'oldField'.
+ // Add 'decorations which carry semantic info' to 'newField', and move all decorations
+ // of 'oldKey' to 'newKey'.
+ // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is
+ // needed to copy between types.
+ for (auto oldField : src->getFields())
+ {
+ auto& fieldMappingNode = mapFieldToField[oldField];
+ fieldMappingNode.setNode(oldField);
+
+ // step 1
+ bool foundSemanticDecor = false;
+ auto oldKey = oldField->getKey();
+ IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration;
+ if (auto oldSemanticDecoration = oldKey->findDecoration<IRSemanticDecoration>())
+ {
+ foundSemanticDecor = true;
+ fieldSemanticDecoration = oldSemanticDecoration;
+ parentLayout = nullptr;
+ }
+
+ IRLayoutDecoration* fieldLayout = parentLayout;
+ if (auto oldLayout = oldKey->findDecoration<IRLayoutDecoration>())
+ {
+ fieldLayout = oldLayout;
+ if (!foundSemanticDecor)
+ fieldSemanticDecoration = nullptr;
+ }
+ if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout)
+ varsWithSemanticInfo.add(oldField);
+
+ // step 2a
+ if (auto structFieldType = as<IRStructType>(oldField->getFieldType()))
+ {
+ _flattenNestedStructs(
+ builder,
+ dst,
+ structFieldType,
+ fieldSemanticDecoration,
+ fieldLayout,
+ fieldMappingNode,
+ varsWithSemanticInfo);
+ continue;
+ }
+
+ // step 3a
+ auto newKey = builder.createStructKey();
+ flattenNestedStructsTransferKeyDecorations(newKey, oldKey);
+
+ auto newField = builder.createStructField(dst, newKey, oldField->getFieldType());
+ copyNameHintAndDebugDecorations(newField, oldField);
+
+ if (fieldSemanticDecoration)
+ builder.addSemanticDecoration(
+ newKey,
+ fieldSemanticDecoration->getSemanticName(),
+ fieldSemanticDecoration->getSemanticIndex());
+
+ if (fieldLayout)
+ {
+ IRLayout* oldLayout = fieldLayout->getLayout();
+ List<IRInst*> instToCopy;
+ // Only copy certain decorations needed for resolving system semantics
+ for (UInt i = 0; i < oldLayout->getOperandCount(); i++)
+ {
+ auto operand = oldLayout->getOperand(i);
+ if (as<IRVarOffsetAttr>(operand) || as<IRUserSemanticAttr>(operand) ||
+ as<IRSystemValueSemanticAttr>(operand) || as<IRStageAttr>(operand))
+ instToCopy.add(operand);
+ }
+ IRVarLayout* newLayout = builder.getVarLayout(instToCopy);
+ builder.addLayoutDecoration(newKey, newLayout);
+ }
+ // step 3b
+ fieldMappingNode.setMapping(newField);
+ }
+
+ return dst;
+ }
+
+ // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there
+ // was no struct flattening.
+ // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct
+ // `IRStructFields`s
+ IRStructType* maybeFlattenNestedStructs(
+ IRBuilder& builder,
+ IRStructType* src,
+ MapStructToFlatStruct& mapFieldToField,
+ HashSet<IRStructField*>& varsWithSemanticInfo)
+ {
+ // Find all values inside struct that need flattening and legalization.
+ bool hasStructTypeMembers = false;
+ for (auto field : src->getFields())
+ {
+ if (as<IRStructType>(field->getFieldType()))
+ {
+ hasStructTypeMembers = true;
+ break;
+ }
+ }
+ if (!hasStructTypeMembers)
+ return src;
+
+ // We need to:
+ // 1. Make new struct 1:1 with old struct but without nestested structs (flatten)
+ // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be
+ // handled later).
+ // 3. Store the mapping from old to new struct fields to allow copying a old-struct to
+ // new-struct.
+ builder.setInsertAfter(src);
+ auto newStruct = builder.createStructType();
+ copyNameHintAndDebugDecorations(newStruct, src);
+ mapFieldToField.setNode(src);
+ return _flattenNestedStructs(
+ builder,
+ newStruct,
+ src,
+ nullptr,
+ nullptr,
+ mapFieldToField,
+ varsWithSemanticInfo);
+ }
+
+ // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'.
+ // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function.
+ template<typename CopyLogicFunc>
+ void _replaceAllReturnInst(
+ IRBuilder& builder,
+ IRFunc* targetFunc,
+ IRStructType* newType,
+ CopyLogicFunc copyLogicFunc)
+ {
+ for (auto block : targetFunc->getBlocks())
+ {
+ if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ {
+ builder.setInsertBefore(returnInst);
+ auto returnVal = returnInst->getVal();
+ returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal));
+ }
+ }
+ }
+
+ UInt _returnNonOverlappingAttributeIndex(std::set<UInt>& usedSemanticIndex)
+ {
+ // Find first unused semantic index of equal semantic type
+ // to fill any gaps in user set semantic bindings
+ UInt prev = 0;
+ for (auto i : usedSemanticIndex)
+ {
+ if (i > prev + 1)
+ {
+ break;
+ }
+ prev = i;
+ }
+ usedSemanticIndex.insert(prev + 1);
+ return prev + 1;
+ }
+
+ template<typename T>
+ struct AttributeParentPair
+ {
+ IRLayoutDecoration* layoutDecor;
+ T* attr;
+ };
+
+ IRLayoutDecoration* _replaceAttributeOfLayout(
+ IRBuilder& builder,
+ IRLayoutDecoration* parentLayoutDecor,
+ IRInst* instToReplace,
+ IRInst* instToReplaceWith)
+ {
+ // Replace `instToReplace` with a `instToReplaceWith`
+
+ auto layout = parentLayoutDecor->getLayout();
+ // Find the exact same decoration `instToReplace` in-case multiple of the same type exist
+ List<IRInst*> opList;
+ opList.add(instToReplaceWith);
+ for (UInt i = 0; i < layout->getOperandCount(); i++)
+ {
+ if (layout->getOperand(i) != instToReplace)
+ opList.add(layout->getOperand(i));
+ }
+ auto newLayoutDecor = builder.addLayoutDecoration(
+ parentLayoutDecor->getParent(),
+ builder.getVarLayout(opList));
+ parentLayoutDecor->removeAndDeallocate();
+ return newLayoutDecor;
+ }
+
+ IRLayoutDecoration* _simplifyUserSemanticNames(
+ IRBuilder& builder,
+ IRLayoutDecoration* layoutDecor)
+ {
+ // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into
+ // ("SV_TARGET", 0) using 'IRUserSemanticAttr' This is done to ensure we can check semantic
+ // groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()'
+ SLANG_ASSERT(layoutDecor);
+ auto layout = layoutDecor->getLayout();
+ List<IRInst*> layoutOps;
+ layoutOps.reserve(3);
+ bool changed = false;
+ for (auto attr : layout->getAllAttrs())
+ {
+ if (auto userSemantic = as<IRUserSemanticAttr>(attr))
+ {
+ UnownedStringSlice outName;
+ UnownedStringSlice outIndex;
+ bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex);
+ if (hasStringIndex)
+ {
+ changed = true;
+ auto loweredName = String(outName).toLower();
+ auto loweredNameSlice = loweredName.getUnownedSlice();
+ auto newDecoration =
+ builder.getUserSemanticAttr(loweredNameSlice, stringToInt(outIndex));
+ userSemantic->replaceUsesWith(newDecoration);
+ userSemantic->removeAndDeallocate();
+ userSemantic = newDecoration;
+ }
+ layoutOps.add(userSemantic);
+ continue;
+ }
+ layoutOps.add(attr);
+ }
+ if (changed)
+ {
+ auto parent = layoutDecor->parent;
+ layoutDecor->removeAndDeallocate();
+ builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps));
+ }
+ return layoutDecor;
+ }
+
+ // Find overlapping field semantics and legalize them
+ void fixFieldSemanticsOfFlatStruct(IRStructType* structType)
+ {
+ // Goal is to ensure we do not have overlapping semantics for the user defined semantics:
+ // Note that in WGSL, the semantics can be either `builtin` without index or `location` with
+ // index.
+ /*
+ // Assume the following code
+ struct Fragment
+ {
+ float4 p0 : SV_POSITION;
+ float2 p1 : TEXCOORD0;
+ float2 p2 : TEXCOORD1;
+ float3 p3 : COLOR0;
+ float3 p4 : COLOR1;
+ };
+
+ // Translates into
+ struct Fragment
+ {
+ float4 p0 : BUILTIN_POSITION;
+ float2 p1 : LOCATION_0;
+ float2 p2 : LOCATION_1;
+ float3 p3 : LOCATION_2;
+ float3 p4 : LOCATION_3;
+ };
+ */
+
+ // For Multi-Render-Target, the semantic index must be translated to `location` with
+ // the same index. Assume the following code
+ /*
+ struct Fragment
+ {
+ float4 p0 : SV_TARGET1;
+ float4 p1 : SV_TARGET0;
+ };
+
+ // Translates into
+ struct Fragment
+ {
+ float4 p0 : LOCATION_1;
+ float4 p1 : LOCATION_0;
+ };
+ */
+
+ IRBuilder builder(this->m_module);
+
+ List<IRSemanticDecoration*> overlappingSemanticsDecor;
+ Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>>
+ usedSemanticIndexSemanticDecor;
+
+ List<AttributeParentPair<IRVarOffsetAttr>> overlappingVarOffset;
+ Dictionary<UInt, std::set<UInt, std::less<UInt>>> usedSemanticIndexVarOffset;
+
+ List<AttributeParentPair<IRUserSemanticAttr>> overlappingUserSemantic;
+ Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>>
+ usedSemanticIndexUserSemantic;
+
+ // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when
+ // legalizing we may destroy and remake a `IRLayoutDecoration*`
+ Dictionary<IRLayoutDecoration*, IRLayoutDecoration*> oldLayoutDecorToNew;
+
+ // Collect all "semantic info carrying decorations". Any collected decoration will
+ // fill up their respective 'Dictionary<SEMANTIC_TYPE, OrderedHashSet<UInt>>'
+ // to keep track of in-use offsets for a semantic type.
+ // Example: IRSemanticDecoration with name of "SV_TARGET1".
+ // * This will have SEMANTIC_TYPE of "sv_target".
+ // * This will use up index '1'
+ //
+ // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to
+ // a list of 'overlapping semantic info decorations' so we can legalize this
+ // 'semantic info decoration' later.
+ //
+ // NOTE: this is a flat struct, all members are children of the initial
+ // IRStructType.
+ for (auto field : structType->getFields())
+ {
+ auto key = field->getKey();
+ if (auto semanticDecoration = key->findDecoration<IRSemanticDecoration>())
+ {
+ auto semanticName = semanticDecoration->getSemanticName();
+
+ // sv_target is treated as a user-semantic because it should be emitted with
+ // @location like how the user semantics are emitted.
+ // For fragment shader, only sv_target will user @location, and for non-fragment
+ // shaders, sv_target is not valid.
+ bool isUserSemantic =
+ (semanticName.startsWithCaseInsensitive(toSlice("sv_target")) ||
+ !semanticName.startsWithCaseInsensitive(toSlice("sv_")));
+
+ // Ensure names are in a uniform lowercase format so we can bunch together simmilar
+ // semantics.
+ UnownedStringSlice outName;
+ UnownedStringSlice outIndex;
+ bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex);
+
+ auto loweredName = String(outName).toLower();
+ auto loweredNameSlice = getUserSemanticNameSlice(loweredName, isUserSemantic);
+ auto semanticIndex =
+ hasStringIndex ? stringToInt(outIndex) : semanticDecoration->getSemanticIndex();
+ auto newDecoration =
+ builder.addSemanticDecoration(key, loweredNameSlice, semanticIndex);
+
+ semanticDecoration->replaceUsesWith(newDecoration);
+ semanticDecoration->removeAndDeallocate();
+ semanticDecoration = newDecoration;
+
+ auto& semanticUse =
+ usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()];
+ if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end())
+ overlappingSemanticsDecor.add(semanticDecoration);
+ else
+ semanticUse.insert(semanticDecoration->getSemanticIndex());
+ }
+ if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
+ {
+ // Ensure names are in a uniform lowercase format so we can bunch together simmilar
+ // semantics
+ layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor);
+ oldLayoutDecorToNew[layoutDecor] = layoutDecor;
+ auto layout = layoutDecor->getLayout();
+ for (auto attr : layout->getAllAttrs())
+ {
+ if (auto offset = as<IRVarOffsetAttr>(attr))
+ {
+ auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()];
+ if (semanticUse.find(offset->getOffset()) != semanticUse.end())
+ overlappingVarOffset.add({layoutDecor, offset});
+ else
+ semanticUse.insert(offset->getOffset());
+ }
+ else if (auto userSemantic = as<IRUserSemanticAttr>(attr))
+ {
+ auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()];
+ if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end())
+ overlappingUserSemantic.add({layoutDecor, userSemantic});
+ else
+ semanticUse.insert(userSemantic->getIndex());
+ }
+ }
+ }
+ }
+
+ // Legalize all overlapping 'semantic info decorations'
+ for (auto decor : overlappingSemanticsDecor)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(
+ usedSemanticIndexSemanticDecor[decor->getSemanticName()]);
+ builder.addSemanticDecoration(
+ decor->getParent(),
+ decor->getSemanticName(),
+ (int)newOffset);
+ decor->removeAndDeallocate();
+ }
+ for (auto& varOffset : overlappingVarOffset)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(
+ usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]);
+ auto newVarOffset = builder.getVarOffsetAttr(
+ varOffset.attr->getResourceKind(),
+ newOffset,
+ varOffset.attr->getSpace());
+ oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout(
+ builder,
+ oldLayoutDecorToNew[varOffset.layoutDecor],
+ varOffset.attr,
+ newVarOffset);
+ }
+ for (auto& userSemantic : overlappingUserSemantic)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(
+ usedSemanticIndexUserSemantic[userSemantic.attr->getName()]);
+ auto newUserSemantic =
+ builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset);
+ oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout(
+ builder,
+ oldLayoutDecorToNew[userSemantic.layoutDecor],
+ userSemantic.attr,
+ newUserSemantic);
+ }
+ }
+
+ void wrapReturnValueInStruct(EntryPointInfo entryPoint)
+ {
+ // Wrap return value into a struct if it is not already a struct.
+ // For example, given this entry point:
+ // ```
+ // float4 main() : SV_Target { return float3(1,2,3); }
+ // ```
+ // We are going to transform it into:
+ // ```
+ // struct Output {
+ // float4 value : SV_Target;
+ // };
+ // Output main() { return {float3(1,2,3)}; }
+
+ auto func = entryPoint.entryPointFunc;
+
+ auto returnType = func->getResultType();
+ if (as<IRVoidType>(returnType))
+ return;
+ auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>();
+ if (!entryPointLayoutDecor)
+ return;
+ auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout());
+ if (!entryPointLayout)
+ return;
+ auto resultLayout = entryPointLayout->getResultLayout();
+
+ // If return type is already a struct, just make sure every field has a semantic.
+ if (auto returnStructType = as<IRStructType>(returnType))
+ {
+ IRBuilder builder(func);
+ MapStructToFlatStruct mapOldFieldToNewField;
+ // Flatten result struct type to ensure we do not have nested semantics
+ auto flattenedStruct = maybeFlattenNestedStructs(
+ builder,
+ returnStructType,
+ mapOldFieldToNewField,
+ semanticInfoToRemove);
+ if (returnStructType != flattenedStruct)
+ {
+ // Replace all return-values with the flattenedStruct we made.
+ _replaceAllReturnInst(
+ builder,
+ func,
+ flattenedStruct,
+ [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
+ {
+ auto srcStructType = as<IRStructType>(srcVal->getDataType());
+ SLANG_ASSERT(srcStructType);
+ auto dstVal = copyBuilder.emitVar(dstType);
+ mapOldFieldToNewField.emitCopy<(
+ int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>(
+ copyBuilder,
+ dstVal,
+ srcVal);
+ return builder.emitLoad(dstVal);
+ });
+ fixUpFuncType(func, flattenedStruct);
+ }
+ // Ensure non-overlapping semantics
+ fixFieldSemanticsOfFlatStruct(flattenedStruct);
+ ensureStructHasUserSemantic<LayoutResourceKind::VaryingOutput>(
+ flattenedStruct,
+ resultLayout);
+ return;
+ }
+
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ IRStructType* structType = builder.createStructType();
+ auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
+ builder.addNameHintDecoration(
+ structType,
+ (String(stageText) + toSlice("Output")).getUnownedSlice());
+ auto key = builder.createStructKey();
+ builder.addNameHintDecoration(key, toSlice("output"));
+ builder.addLayoutDecoration(key, resultLayout);
+ builder.createStructField(structType, key, returnType);
+ IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder);
+ structTypeLayoutBuilder.addField(key, resultLayout);
+ auto typeLayout = structTypeLayoutBuilder.build();
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+ auto varLayout = varLayoutBuilder.build();
+ ensureStructHasUserSemantic<LayoutResourceKind::VaryingOutput>(structType, varLayout);
+
+ _replaceAllReturnInst(
+ builder,
+ func,
+ structType,
+ [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
+ { return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); });
+
+ // Assign an appropriate system value semantic for stage output
+ auto stage = entryPoint.entryPointDecor->getProfile().getStage();
+ switch (stage)
+ {
+ case Stage::Compute:
+ case Stage::Fragment:
+ {
+ addFragmentShaderReturnValueDecoration(builder, key);
+ break;
+ }
+ case Stage::Vertex:
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("position"));
+ break;
+ }
+ default:
+ SLANG_ASSERT(false);
+ return;
+ }
+
+ fixUpFuncType(func, structType);
+ }
+
+ IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
+ {
+ auto fromType = val->getFullType();
+ if (auto fromVector = as<IRVectorType>(fromType))
+ {
+ if (auto toVector = as<IRVectorType>(toType))
+ {
+ if (fromVector->getElementCount() != toVector->getElementCount())
+ {
+ fromType = builder.getVectorType(
+ fromVector->getElementType(),
+ toVector->getElementCount());
+ val = builder.emitVectorReshape(fromType, val);
+ }
+ }
+ else if (as<IRBasicType>(toType))
+ {
+ UInt index = 0;
+ val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
+ if (toType->getOp() == kIROp_VoidType)
+ return nullptr;
+ }
+ }
+ else if (auto fromBasicType = as<IRBasicType>(fromType))
+ {
+ if (fromBasicType->getOp() == kIROp_VoidType)
+ return nullptr;
+ if (!as<IRBasicType>(toType))
+ return nullptr;
+ if (toType->getOp() == kIROp_VoidType)
+ return nullptr;
+ }
+ else
+ {
+ return nullptr;
+ }
+ return builder.emitCast(toType, val);
+ }
+
+ void legalizeSystemValueParameters(EntryPointInfo entryPoint)
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems =
+ collectSystemValFromEntryPoint(entryPoint);
+
+ for (auto index = 0; index < systemValWorkItems.getCount(); index++)
+ {
+ legalizeSystemValue(entryPoint, systemValWorkItems[index]);
+ }
+ fixUpFuncType(entryPoint.entryPointFunc);
+ }
+
+ void legalizeEntryPoint(EntryPointInfo entryPoint)
+ {
+ // If the entrypoint is receiving varying inputs as a pointer, turn it into a value.
+ depointerizeInputParams(entryPoint.entryPointFunc);
+
+ // TODO FIXME: Enable these for WGSL and remove the `hoistParemeters` member field.
+ // WGSL entry point legalization currently only applies attributes to struct parameters,
+ // apply the same hoisting from Metal to WGSL to fix it.
+ if (hoistParameters)
+ {
+ hoistEntryPointParameterFromStruct(entryPoint);
+ packStageInParameters(entryPoint);
+ }
+
+ // Input Parameter Legalize
+ flattenInputParameters(entryPoint);
+
+ // System Value Legalize
+ legalizeSystemValueParameters(entryPoint);
+
+ // Output Value Legalize
+ wrapReturnValueInStruct(entryPoint);
+
+
+ // Other Legalize
+ switch (entryPoint.entryPointDecor->getProfile().getStage())
+ {
+ case Stage::Amplification:
+ legalizeAmplificationStageEntryPoint(entryPoint);
+ break;
+ case Stage::Mesh:
+ legalizeMeshStageEntryPoint(entryPoint);
+ break;
+ default:
+ break;
+ }
+ }
+};
+
+class LegalizeMetalEntryPointContext : public LegalizeShaderEntryPointContext
+{
+public:
+ LegalizeMetalEntryPointContext(IRModule* module, DiagnosticSink* sink)
+ : LegalizeShaderEntryPointContext(module, sink, true)
+ {
+ generatePermittedTypes_sv_target();
+ }
+
+protected:
+ SystemValueInfo getSystemValueInfo(
+ String inSemanticName,
+ String* optionalSemanticIndex,
+ IRInst* parentVar) const SLANG_OVERRIDE
+ {
+ IRBuilder builder(m_module);
+ SystemValueInfo result = {};
+ UnownedStringSlice semanticName;
+ UnownedStringSlice semanticIndex;
+
+ auto hasExplicitIndex =
+ splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
+ if (!hasExplicitIndex && optionalSemanticIndex)
+ semanticIndex = optionalSemanticIndex->getUnownedSlice();
+
+ result.systemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName);
+
+ switch (result.systemValueNameEnum)
+ {
+ case SystemValueSemanticName::Position:
+ {
+ result.systemValueName = toSlice("position");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::Float),
+ builder.getIntValue(builder.getIntType(), 4)));
+ break;
+ }
+ case SystemValueSemanticName::ClipDistance:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::CullDistance:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::Coverage:
+ {
+ result.systemValueName = toSlice("sample_mask");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::InnerCoverage:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::Depth:
+ {
+ result.systemValueName = toSlice("depth(any)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DepthGreaterEqual:
+ {
+ result.systemValueName = toSlice("depth(greater)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DepthLessEqual:
+ {
+ result.systemValueName = toSlice("depth(less)");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::DispatchThreadID:
+ {
+ result.systemValueName = toSlice("thread_position_in_grid");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)));
+ break;
+ }
+ case SystemValueSemanticName::DomainLocation:
+ {
+ result.systemValueName = toSlice("position_in_patch");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::Float),
+ builder.getIntValue(builder.getIntType(), 3)));
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::Float),
+ builder.getIntValue(builder.getIntType(), 2)));
+ break;
+ }
+ case SystemValueSemanticName::GroupID:
+ {
+ result.systemValueName = toSlice("threadgroup_position_in_grid");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)));
+ break;
+ }
+ case SystemValueSemanticName::GroupIndex:
+ {
+ result.isSpecial = true;
+ break;
+ }
+ case SystemValueSemanticName::GroupThreadID:
+ {
+ result.systemValueName = toSlice("thread_position_in_threadgroup");
+ result.permittedTypes.add(getGroupThreadIdType(builder));
+ break;
+ }
+ case SystemValueSemanticName::GSInstanceID:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+ case SystemValueSemanticName::InstanceID:
+ {
+ result.systemValueName = toSlice("instance_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::IsFrontFace:
+ {
+ result.systemValueName = toSlice("front_facing");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Bool));
+ break;
+ }
+ case SystemValueSemanticName::OutputControlPointID:
+ {
+ // In metal, a hull shader is just a compute shader.
+ // This needs to be handled separately, by lowering into an ordinary buffer.
+ break;
+ }
+ case SystemValueSemanticName::PointSize:
+ {
+ result.systemValueName = toSlice("point_size");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ break;
+ }
+ case SystemValueSemanticName::PrimitiveID:
+ {
+ result.systemValueName = toSlice("primitive_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::RenderTargetArrayIndex:
+ {
+ result.systemValueName = toSlice("render_target_array_index");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::SampleIndex:
+ {
+ result.systemValueName = toSlice("sample_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::StencilRef:
+ {
+ result.systemValueName = toSlice("stencil");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::TessFactor:
+ {
+ // Tessellation factor outputs should be lowered into a write into a normal buffer.
+ break;
+ }
+ case SystemValueSemanticName::VertexID:
+ {
+ result.systemValueName = toSlice("vertex_id");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::ViewID:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+ case SystemValueSemanticName::ViewportArrayIndex:
+ {
+ result.systemValueName = toSlice("viewport_array_index");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
+ break;
+ }
+ case SystemValueSemanticName::Target:
+ {
+ result.systemValueName =
+ (StringBuilder()
+ << "color(" << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0"))
+ << ")")
+ .produceString();
+ result.permittedTypes = permittedTypes_sv_target;
+
+ break;
+ }
+ case SystemValueSemanticName::StartVertexLocation:
+ {
+ result.systemValueName = toSlice("base_vertex");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ case SystemValueSemanticName::StartInstanceLocation:
+ {
+ result.systemValueName = toSlice("base_instance");
+ result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
+ break;
+ }
+ default:
+ m_sink->diagnose(
+ parentVar,
+ Diagnostics::unimplementedSystemValueSemantic,
+ semanticName);
+ return result;
+ }
+ return result;
+ }
+
+
+ List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(
+ EntryPointInfo entryPoint) const SLANG_OVERRIDE
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems;
+ for (auto param : entryPoint.entryPointFunc->getParams())
+ {
+ auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType());
+ if (maybeWorkItem.has_value())
+ systemValWorkItems.add(std::move(maybeWorkItem.value()));
+ }
+
+ return systemValWorkItems;
+ }
+
+ void flattenNestedStructsTransferKeyDecorations(IRInst* newKey, IRInst* oldKey) const
+ SLANG_OVERRIDE
+ {
+ copyNameHintAndDebugDecorations(newKey, oldKey);
+ }
+
+ UnownedStringSlice getUserSemanticNameSlice(String& loweredName, bool isUserSemantic) const
+ SLANG_OVERRIDE
+ {
+ SLANG_UNUSED(isUserSemantic);
+ return loweredName.getUnownedSlice();
+ };
+
+ void addFragmentShaderReturnValueDecoration(IRBuilder& builder, IRInst* returnValueStructKey)
+ const SLANG_OVERRIDE
+ {
+ builder.addTargetSystemValueDecoration(returnValueStructKey, toSlice("color(0)"));
+ }
+
+ IRVarLayout* handleGeometryStageParameterVarLayout(
+ IRBuilder& builder,
+ IRVarLayout* paramVarLayout) const SLANG_OVERRIDE
+ {
+ // For Metal geometric stages, we need to translate VaryingInput offsets to
+ // MetalAttribute offsets.
+ IRVarLayout::Builder elementVarLayoutBuilder(&builder, paramVarLayout->getTypeLayout());
+ elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout);
+ for (auto offsetAttr : paramVarLayout->getOffsetAttrs())
+ {
+ auto resourceKind = offsetAttr->getResourceKind();
+ if (resourceKind == LayoutResourceKind::VaryingInput)
+ {
+ resourceKind = LayoutResourceKind::MetalAttribute;
+ }
+ auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind);
+ resInfo->offset = offsetAttr->getOffset();
+ resInfo->space = offsetAttr->getSpace();
+ }
+
+ return elementVarLayoutBuilder.build();
+ }
+
+ void handleSpecialSystemValue(
+ const EntryPointInfo& entryPoint,
+ SystemValLegalizationWorkItem& workItem,
+ const SystemValueInfo& info,
+ IRBuilder& builder) SLANG_OVERRIDE
+ {
+ const auto var = workItem.var;
+
+ if (info.systemValueNameEnum == SystemValueSemanticName::InnerCoverage)
+ {
+ // Metal does not support conservative rasterization, so this is always false.
+ auto val = builder.getBoolValue(false);
+ var->replaceUsesWith(val);
+ var->removeAndDeallocate();
+ }
+ else if (info.systemValueNameEnum == SystemValueSemanticName::GroupIndex)
+ {
+ // Ensure we have a cached "sv_groupthreadid" in our entry point
+ if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc))
+ {
+ auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint);
+ for (auto i : systemValWorkItems)
+ {
+ auto indexAsStringGroupThreadId = String(i.attrIndex);
+ if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var)
+ .systemValueNameEnum == SystemValueSemanticName::GroupThreadID)
+ {
+ entryPointToGroupThreadId[entryPoint.entryPointFunc] = i.var;
+ }
+ }
+ if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc))
+ {
+ // Add the missing groupthreadid needed to compute sv_groupindex
+ IRBuilder groupThreadIdBuilder(builder);
+ groupThreadIdBuilder.setInsertInto(entryPoint.entryPointFunc->getFirstBlock());
+ auto groupThreadId = groupThreadIdBuilder.emitParamAtHead(
+ getGroupThreadIdType(groupThreadIdBuilder));
+ entryPointToGroupThreadId[entryPoint.entryPointFunc] = groupThreadId;
+ groupThreadIdBuilder.addNameHintDecoration(groupThreadId, groupThreadIDString);
+
+ // Since "sv_groupindex" will be translated out to a global var and no
+ // longer be considered a system value we can reuse its layout and
+ // semantic info
+ Index foundRequiredDecorations = 0;
+ IRLayoutDecoration* layoutDecoration = nullptr;
+ UInt semanticIndex = 0;
+ for (auto decoration : var->getDecorations())
+ {
+ if (auto layoutDecorationTmp = as<IRLayoutDecoration>(decoration))
+ {
+ layoutDecoration = layoutDecorationTmp;
+ foundRequiredDecorations++;
+ }
+ else if (auto semanticDecoration = as<IRSemanticDecoration>(decoration))
+ {
+ semanticIndex = semanticDecoration->getSemanticIndex();
+ groupThreadIdBuilder.addSemanticDecoration(
+ groupThreadId,
+ groupThreadIDString,
+ (int)semanticIndex);
+ foundRequiredDecorations++;
+ }
+ if (foundRequiredDecorations >= 2)
+ break;
+ }
+ SLANG_ASSERT(layoutDecoration);
+ layoutDecoration->removeFromParent();
+ layoutDecoration->insertAtStart(groupThreadId);
+ SystemValLegalizationWorkItem newWorkItem = {
+ groupThreadId,
+ groupThreadId->getFullType(),
+ groupThreadIDString,
+ semanticIndex};
+ legalizeSystemValue(entryPoint, newWorkItem);
+ }
+ }
+
+ IRBuilder svBuilder(builder.getModule());
+ svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst());
+ auto uint3Type = builder.getVectorType(
+ builder.getUIntType(),
+ builder.getIntValue(builder.getIntType(), 3));
+ auto computeExtent =
+ emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type);
+ if (!computeExtent)
+ {
+ m_sink->diagnose(
+ entryPoint.entryPointFunc,
+ Diagnostics::unsupportedSpecializationConstantForNumThreads);
+
+ // Fill in placeholder values.
+ static const int kAxisCount = 3;
+ IRInst* groupExtentAlongAxis[kAxisCount] = {};
+ for (int axis = 0; axis < kAxisCount; axis++)
+ groupExtentAlongAxis[axis] =
+ builder.getIntValue(uint3Type->getElementType(), 1);
+ computeExtent = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis);
+ }
+ auto groupIndexCalc = emitCalcGroupIndex(
+ svBuilder,
+ entryPointToGroupThreadId[entryPoint.entryPointFunc],
+ computeExtent);
+ svBuilder.addNameHintDecoration(groupIndexCalc, UnownedStringSlice("sv_groupindex"));
+
+ var->replaceUsesWith(groupIndexCalc);
+ var->removeAndDeallocate();
+ }
+ }
+
+ void legalizeAmplificationStageEntryPoint(const EntryPointInfo& entryPoint) const SLANG_OVERRIDE
+ {
+ // Find out DispatchMesh function
+ IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
+ for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
+ {
+ if (const auto func = as<IRGlobalValueWithCode>(globalInst))
+ {
+ if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
+ {
+ if (dec->getName() == "DispatchMesh")
+ {
+ SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
+ dispatchMeshFunc = func;
+ }
+ }
+ }
+ }
+
+ if (!dispatchMeshFunc)
+ return;
+
+ IRBuilder builder{entryPoint.entryPointFunc->getModule()};
+
+ // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid
+ traverseUses(
+ dispatchMeshFunc,
+ [&](const IRUse* use)
+ {
+ if (const auto call = as<IRCall>(use->getUser()))
+ {
+ SLANG_ASSERT(call->getArgCount() == 4);
+ const auto payload = call->getArg(3);
+
+ const auto payloadPtrType =
+ composeGetters<IRPtrType>(payload, &IRInst::getDataType);
+ SLANG_ASSERT(payloadPtrType);
+ const auto payloadType = payloadPtrType->getValueType();
+ SLANG_ASSERT(payloadType);
+
+ builder.setInsertBefore(
+ entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+ const auto annotatedPayloadType = builder.getPtrType(
+ kIROp_RefType,
+ payloadPtrType->getValueType(),
+ AddressSpace::MetalObjectData);
+ auto packedParam = builder.emitParam(annotatedPayloadType);
+ builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload"));
+ IRVarLayout::Builder varLayoutBuilder(
+ &builder,
+ IRTypeLayout::Builder{&builder}.build());
+
+ // Add the MetalPayload resource info, so we can emit [[payload]]
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(packedParam, paramVarLayout);
+
+ // Now we replace the call to DispatchMesh with a call to the mesh grid
+ // properties But first we need to create the parameter
+ const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
+ auto mgp = builder.emitParam(meshGridPropertiesType);
+ builder.addExternCppDecoration(mgp, toSlice("_slang_mgp"));
+ }
+ });
+ }
+
+ void legalizeMeshStageEntryPoint(const EntryPointInfo& entryPoint) const SLANG_OVERRIDE
+ {
+ auto func = entryPoint.entryPointFunc;
+
+ IRBuilder builder{func->getModule()};
+ for (auto param : func->getParams())
+ {
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ {
+ IRVarLayout::Builder varLayoutBuilder(
+ &builder,
+ IRTypeLayout::Builder{&builder}.build());
+
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(param, paramVarLayout);
+
+ IRPtrTypeBase* type = as<IRPtrTypeBase>(param->getDataType());
+
+ const auto annotatedPayloadType = builder.getPtrType(
+ kIROp_ConstRefType,
+ type->getValueType(),
+ AddressSpace::MetalObjectData);
+
+ param->setFullType(annotatedPayloadType);
+ }
+ }
+ IROutputTopologyDecoration* outputDeco =
+ entryPoint.entryPointFunc->findDecoration<IROutputTopologyDecoration>();
+ if (outputDeco == nullptr)
+ {
+ SLANG_UNEXPECTED("Mesh shader output decoration missing");
+ return;
+ }
+ const auto topology = outputDeco->getTopology();
+ const auto topStr = topology->getStringSlice();
+ UInt topologyEnum = 0;
+ if (topStr.caseInsensitiveEquals(toSlice("point")))
+ {
+ topologyEnum = 1;
+ }
+ else if (topStr.caseInsensitiveEquals(toSlice("line")))
+ {
+ topologyEnum = 2;
+ }
+ else if (topStr.caseInsensitiveEquals(toSlice("triangle")))
+ {
+ topologyEnum = 3;
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unknown topology");
+ return;
+ }
+
+ IRInst* topologyConst = builder.getIntValue(builder.getIntType(), topologyEnum);
+
+ IRType* vertexType = nullptr;
+ IRType* indicesType = nullptr;
+ IRType* primitiveType = nullptr;
+
+ IRInst* maxVertices = nullptr;
+ IRInst* maxPrimitives = nullptr;
+
+ IRInst* verticesParam = nullptr;
+ IRInst* indicesParam = nullptr;
+ IRInst* primitivesParam = nullptr;
+ for (auto param : func->getParams())
+ {
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ {
+ IRVarLayout::Builder varLayoutBuilder(
+ &builder,
+ IRTypeLayout::Builder{&builder}.build());
+
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(param, paramVarLayout);
+ }
+ if (param->findDecorationImpl(kIROp_VerticesDecoration))
+ {
+ auto vertexRefType = as<IRPtrTypeBase>(param->getDataType());
+ auto vertexOutputType = as<IRVerticesType>(vertexRefType->getValueType());
+ vertexType = vertexOutputType->getElementType();
+ maxVertices = vertexOutputType->getMaxElementCount();
+ SLANG_ASSERT(vertexType);
+
+ verticesParam = param;
+ auto vertStruct = as<IRStructType>(vertexType);
+ for (auto field : vertStruct->getFields())
+ {
+ auto key = field->getKey();
+ if (auto deco = key->findDecoration<IRSemanticDecoration>())
+ {
+ if (deco->getSemanticName().caseInsensitiveEquals(toSlice("sv_position")))
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("position"));
+ }
+ }
+ }
+ }
+ if (param->findDecorationImpl(kIROp_IndicesDecoration))
+ {
+ auto indicesRefType = (IRConstRefType*)param->getDataType();
+ auto indicesOutputType = (IRIndicesType*)indicesRefType->getValueType();
+ indicesType = indicesOutputType->getElementType();
+ maxPrimitives = indicesOutputType->getMaxElementCount();
+ SLANG_ASSERT(indicesType);
+
+ indicesParam = param;
+ }
+ if (param->findDecorationImpl(kIROp_PrimitivesDecoration))
+ {
+ auto primitivesRefType = (IRConstRefType*)param->getDataType();
+ auto primitivesOutputType = (IRPrimitivesType*)primitivesRefType->getValueType();
+ primitiveType = primitivesOutputType->getElementType();
+ SLANG_ASSERT(primitiveType);
+
+ primitivesParam = param;
+ auto primStruct = as<IRStructType>(primitiveType);
+ for (auto field : primStruct->getFields())
+ {
+ auto key = field->getKey();
+ if (auto deco = key->findDecoration<IRSemanticDecoration>())
+ {
+ if (deco->getSemanticName().caseInsensitiveEquals(
+ toSlice("sv_primitiveid")))
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("primitive_id"));
+ }
+ }
+ }
+ }
+ }
+ if (primitiveType == nullptr)
+ {
+ primitiveType = builder.getVoidType();
+ }
+ builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+
+ auto meshParam = builder.emitParam(builder.getMetalMeshType(
+ vertexType,
+ primitiveType,
+ maxVertices,
+ maxPrimitives,
+ topologyConst));
+ builder.addExternCppDecoration(meshParam, toSlice("_slang_mesh"));
+
+
+ verticesParam->removeFromParent();
+ verticesParam->removeAndDeallocate();
+
+ indicesParam->removeFromParent();
+ indicesParam->removeAndDeallocate();
+
+ if (primitivesParam != nullptr)
+ {
+ primitivesParam->removeFromParent();
+ primitivesParam->removeAndDeallocate();
+ }
+ }
+
+private:
+ ShortList<IRType*> permittedTypes_sv_target;
+ Dictionary<IRFunc*, IRInst*> entryPointToGroupThreadId;
+ const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid");
+
+ static IRType* getGroupThreadIdType(IRBuilder& builder)
+ {
+ return builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3));
+ }
+
+ void generatePermittedTypes_sv_target()
+ {
+ IRBuilder builder(m_module);
+ permittedTypes_sv_target.reserveOverflowBuffer(5 * 4);
+ if (permittedTypes_sv_target.getCount() == 0)
+ {
+ for (auto baseType :
+ {BaseType::Float,
+ BaseType::Half,
+ BaseType::Int,
+ BaseType::UInt,
+ BaseType::Int16,
+ BaseType::UInt16})
+ {
+ for (IRIntegerValue i = 1; i <= 4; i++)
+ {
+ permittedTypes_sv_target.add(
+ builder.getVectorType(builder.getBasicType(baseType), i));
+ }
+ }
+ }
+ }
+};
+
+
+class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext
+{
+public:
+ LegalizeWGSLEntryPointContext(IRModule* module, DiagnosticSink* sink)
+ : LegalizeShaderEntryPointContext(module, sink, false)
+ {
+ }
+
+protected:
+ SystemValueInfo getSystemValueInfo(
+ String inSemanticName,
+ String* optionalSemanticIndex,
+ IRInst* parentVar) const SLANG_OVERRIDE
+ {
+ IRBuilder builder(m_module);
+ SystemValueInfo result = {};
+ UnownedStringSlice semanticName;
+ UnownedStringSlice semanticIndex;
+
+ auto hasExplicitIndex =
+ splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
+ if (!hasExplicitIndex && optionalSemanticIndex)
+ semanticIndex = optionalSemanticIndex->getUnownedSlice();
+
+ result.systemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName);
+
+ switch (result.systemValueNameEnum)
+ {
+
+ case SystemValueSemanticName::CullDistance:
+ {
+ result.isUnsupported = true;
+ }
+ break;
+
+ case SystemValueSemanticName::ClipDistance:
+ {
+ // TODO: Implement this based on the 'clip-distances' feature in WGSL
+ // https: // www.w3.org/TR/webgpu/#dom-gpufeaturename-clip-distances
+ result.isUnsupported = true;
+ }
+ break;
+
+ case SystemValueSemanticName::Coverage:
+ {
+ result.systemValueName = toSlice("sample_mask");
+ result.permittedTypes.add(builder.getUIntType());
+ }
+ break;
+
+ case SystemValueSemanticName::Depth:
+ {
+ result.systemValueName = toSlice("frag_depth");
+ result.permittedTypes.add(builder.getBasicType(BaseType::Float));
+ }
+ break;
+
+ case SystemValueSemanticName::DepthGreaterEqual:
+ case SystemValueSemanticName::DepthLessEqual:
+ {
+ result.isUnsupported = true;
+ }
+ break;
+
+ case SystemValueSemanticName::DispatchThreadID:
+ {
+ result.systemValueName = toSlice("global_invocation_id");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)));
+ }
+ break;
+
+ case SystemValueSemanticName::DomainLocation:
+ {
+ result.isUnsupported = true;
+ }
+ break;
+
+ case SystemValueSemanticName::GroupID:
+ {
+ result.systemValueName = toSlice("workgroup_id");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)));
+ }
+ break;
+
+ case SystemValueSemanticName::GroupIndex:
+ {
+ result.systemValueName = toSlice("local_invocation_index");
+ result.permittedTypes.add(builder.getUIntType());
+ }
+ break;
+
+ case SystemValueSemanticName::GroupThreadID:
+ {
+ result.systemValueName = toSlice("local_invocation_id");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)));
+ }
+ break;
+
+ case SystemValueSemanticName::GSInstanceID:
+ {
+ // No Geometry shaders in WGSL
+ result.isUnsupported = true;
+ }
+ break;
+
+ case SystemValueSemanticName::InnerCoverage:
+ {
+ result.isUnsupported = true;
+ }
+ break;
+
+ case SystemValueSemanticName::InstanceID:
+ {
+ result.systemValueName = toSlice("instance_index");
+ result.permittedTypes.add(builder.getUIntType());
+ }
+ break;
+
+ case SystemValueSemanticName::IsFrontFace:
+ {
+ result.systemValueName = toSlice("front_facing");
+ result.permittedTypes.add(builder.getBoolType());
+ }
+ break;
+
+ case SystemValueSemanticName::OutputControlPointID:
+ case SystemValueSemanticName::PointSize:
+ {
+ result.isUnsupported = true;
+ }
+ break;
+
+ case SystemValueSemanticName::Position:
+ {
+ result.systemValueName = toSlice("position");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::Float),
+ builder.getIntValue(builder.getIntType(), 4)));
+ break;
+ }
+
+ case SystemValueSemanticName::PrimitiveID:
+ case SystemValueSemanticName::RenderTargetArrayIndex:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+
+ case SystemValueSemanticName::SampleIndex:
+ {
+ result.systemValueName = toSlice("sample_index");
+ result.permittedTypes.add(builder.getUIntType());
+ break;
+ }
+
+ case SystemValueSemanticName::StencilRef:
+ case SystemValueSemanticName::Target:
+ case SystemValueSemanticName::TessFactor:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+
+ case SystemValueSemanticName::VertexID:
+ {
+ result.systemValueName = toSlice("vertex_index");
+ result.permittedTypes.add(builder.getUIntType());
+ break;
+ }
+
+ case SystemValueSemanticName::ViewID:
+ case SystemValueSemanticName::ViewportArrayIndex:
+ case SystemValueSemanticName::StartVertexLocation:
+ case SystemValueSemanticName::StartInstanceLocation:
+ {
+ result.isUnsupported = true;
+ break;
+ }
+
+ default:
+ {
+ m_sink->diagnose(
+ parentVar,
+ Diagnostics::unimplementedSystemValueSemantic,
+ semanticName);
+ return result;
+ }
+ }
+
+ return result;
+ }
+ void flattenNestedStructsTransferKeyDecorations(IRInst* newKey, IRInst* oldKey) const
+ SLANG_OVERRIDE
+ {
+ oldKey->transferDecorationsTo(newKey);
+ }
+
+ UnownedStringSlice getUserSemanticNameSlice(String& loweredName, bool isUserSemantic) const
+ SLANG_OVERRIDE
+ {
+ return isUserSemantic ? userSemanticName : loweredName.getUnownedSlice();
+ }
+
+ void addFragmentShaderReturnValueDecoration(IRBuilder& builder, IRInst* returnValueStructKey)
+ const SLANG_OVERRIDE
+ {
+ IRInst* operands[] = {
+ builder.getStringValue(userSemanticName),
+ builder.getIntValue(builder.getIntType(), 0)};
+ builder.addDecoration(
+ returnValueStructKey,
+ kIROp_SemanticDecoration,
+ operands,
+ SLANG_COUNT_OF(operands));
+ };
+
+ List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(
+ EntryPointInfo entryPoint) const SLANG_OVERRIDE
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems;
+ for (auto param : entryPoint.entryPointFunc->getParams())
+ {
+ if (auto structType = as<IRStructType>(param->getDataType()))
+ {
+ for (auto field : structType->getFields())
+ {
+ // Nested struct-s are flattened already by flattenInputParameters().
+ SLANG_ASSERT(!as<IRStructType>(field->getFieldType()));
+
+ auto key = field->getKey();
+ auto fieldType = field->getFieldType();
+ auto maybeWorkItem = tryToMakeSystemValWorkItem(key, fieldType);
+ if (maybeWorkItem.has_value())
+ systemValWorkItems.add(std::move(maybeWorkItem.value()));
+ }
+ continue;
+ }
+
+ auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType());
+ if (maybeWorkItem.has_value())
+ systemValWorkItems.add(std::move(maybeWorkItem.value()));
+ }
+
+ return systemValWorkItems;
+ }
+
+private:
+ const UnownedStringSlice userSemanticName = toSlice("user_semantic");
+};
+
+void legalizeEntryPointVaryingParamsForMetal(
+ IRModule* module,
+ DiagnosticSink* sink,
+ List<EntryPointInfo>& entryPoints)
+{
+ LegalizeMetalEntryPointContext context(module, sink);
+ context.legalizeEntryPoints(entryPoints);
+}
+
+void legalizeEntryPointVaryingParamsForWGSL(
+ IRModule* module,
+ DiagnosticSink* sink,
+ List<EntryPointInfo>& entryPoints)
+{
+ LegalizeWGSLEntryPointContext context(module, sink);
+ context.legalizeEntryPoints(entryPoints);
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
index efd61e87c..e742f3093 100644
--- a/source/slang/slang-ir-legalize-varying-params.h
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -14,19 +14,27 @@ struct IRVectorType;
struct IRBuilder;
struct IREntryPointDecoration;
+struct EntryPointInfo
+{
+ IRFunc* entryPointFunc;
+ IREntryPointDecoration* entryPointDecor;
+};
+
void legalizeEntryPointVaryingParamsForCPU(IRModule* module, DiagnosticSink* sink);
void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* sink);
-void depointerizeInputParams(IRFunc* entryPoint);
-
-// (#4375) Once `slang-ir-metal-legalize.cpp` is merged with
-// `slang-ir-legalize-varying-params.cpp`, move the following
-// below into `slang-ir-legalize-varying-params.cpp` as well
+void legalizeEntryPointVaryingParamsForMetal(
+ IRModule* module,
+ DiagnosticSink* sink,
+ List<EntryPointInfo>& entryPoints);
-IRInst* emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorType* type);
+void legalizeEntryPointVaryingParamsForWGSL(
+ IRModule* module,
+ DiagnosticSink* sink,
+ List<EntryPointInfo>& entryPoints);
-IRInst* emitCalcGroupIndex(IRBuilder& builder, IRInst* groupThreadID, IRInst* groupExtents);
+void depointerizeInputParams(IRFunc* entryPoint);
// SystemValueSemanticName member definition macro
#define SYSTEM_VALUE_SEMANTIC_NAMES(M) \
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 3b47bd59e..0d58bdd14 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -7,1964 +7,10 @@
#include "slang-ir-specialize-address-space.h"
#include "slang-ir-util.h"
#include "slang-ir.h"
-#include "slang-parameter-binding.h"
-
-#include <set>
namespace Slang
{
-struct EntryPointInfo
-{
- IRFunc* entryPointFunc;
- IREntryPointDecoration* entryPointDecor;
-};
-
-const UnownedStringSlice groupThreadIDString = UnownedStringSlice("sv_groupthreadid");
-struct LegalizeMetalEntryPointContext
-{
- ShortList<IRType*> permittedTypes_sv_target;
- Dictionary<IRFunc*, IRInst*> entryPointToGroupThreadId;
- HashSet<IRStructField*> semanticInfoToRemove;
-
- DiagnosticSink* m_sink;
- IRModule* m_module;
-
- LegalizeMetalEntryPointContext(DiagnosticSink* sink, IRModule* module)
- : m_sink(sink), m_module(module)
- {
- }
-
- void removeSemanticLayoutsFromLegalizedStructs()
- {
- // Metal does not allow duplicate attributes to appear in the same shader.
- // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]`
- // must be removed.
- for (auto field : semanticInfoToRemove)
- {
- auto key = field->getKey();
- // Some decorations appear twice, destroy all found
- for (;;)
- {
- if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
- {
- semanticDecor->removeAndDeallocate();
- continue;
- }
- else if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
- {
- layoutDecor->removeAndDeallocate();
- continue;
- }
- break;
- }
- }
- }
-
- void hoistEntryPointParameterFromStruct(EntryPointInfo entryPoint)
- {
- // If an entry point has a input parameter with a struct type, we want to hoist out
- // all the fields of the struct type to be individual parameters of the entry point.
- // This will canonicalize the entry point signature, so we can handle all cases uniformly.
-
- // For example, given an entry point:
- // ```
- // struct VertexInput { float3 pos; float 2 uv; int vertexId : SV_VertexID};
- // void main(VertexInput vin) { ... }
- // ```
- // We will transform it to:
- // ```
- // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
- // VertexInput vin = {pos,uv,vertexId};
- // ...
- // }
- // ```
-
- auto func = entryPoint.entryPointFunc;
- List<IRParam*> paramsToProcess;
- for (auto param : func->getParams())
- {
- if (as<IRStructType>(param->getDataType()))
- {
- paramsToProcess.add(param);
- }
- }
-
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- for (auto param : paramsToProcess)
- {
- auto structType = as<IRStructType>(param->getDataType());
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto varLayout = findVarLayout(param);
-
- // If `param` already has a semantic, we don't want to hoist its fields out.
- if (varLayout->findSystemValueSemanticAttr() != nullptr ||
- param->findDecoration<IRSemanticDecoration>())
- continue;
-
- IRStructTypeLayout* structTypeLayout = nullptr;
- if (varLayout)
- structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
- Index fieldIndex = 0;
- List<IRInst*> fieldParams;
- for (auto field : structType->getFields())
- {
- auto fieldParam = builder.emitParam(field->getFieldType());
- IRCloneEnv cloneEnv;
- cloneInstDecorationsAndChildren(
- &cloneEnv,
- builder.getModule(),
- field->getKey(),
- fieldParam);
-
- IRVarLayout* fieldLayout =
- structTypeLayout ? structTypeLayout->getFieldLayout(fieldIndex) : nullptr;
- if (varLayout)
- {
- IRVarLayout::Builder varLayoutBuilder(&builder, fieldLayout->getTypeLayout());
- varLayoutBuilder.cloneEverythingButOffsetsFrom(fieldLayout);
- for (auto offsetAttr : fieldLayout->getOffsetAttrs())
- {
- auto parentOffsetAttr =
- varLayout->findOffsetAttr(offsetAttr->getResourceKind());
- UInt parentOffset = parentOffsetAttr ? parentOffsetAttr->getOffset() : 0;
- UInt parentSpace = parentOffsetAttr ? parentOffsetAttr->getSpace() : 0;
- auto resInfo =
- varLayoutBuilder.findOrAddResourceInfo(offsetAttr->getResourceKind());
- resInfo->offset = parentOffset + offsetAttr->getOffset();
- resInfo->space = parentSpace + offsetAttr->getSpace();
- }
- builder.addLayoutDecoration(fieldParam, varLayoutBuilder.build());
- }
- param->insertBefore(fieldParam);
- fieldParams.add(fieldParam);
- fieldIndex++;
- }
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto reconstructedParam =
- builder.emitMakeStruct(structType, fieldParams.getCount(), fieldParams.getBuffer());
- param->replaceUsesWith(reconstructedParam);
- param->removeFromParent();
- }
- fixUpFuncType(func);
- }
-
- // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct
- void flattenInputParameters(EntryPointInfo entryPoint)
- {
- // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members).
- /*
- // Assume the following code
- struct NestedFragment
- {
- float2 p3;
- };
- struct Fragment
- {
- float4 p1;
- float3 p2;
- NestedFragment p3_nested;
- };
-
- // Fragment flattens into
- struct Fragment
- {
- float4 p1;
- float3 p2;
- float2 p3;
- };
- */
-
- // This is important since Metal does not allow semantic's on a struct
- /*
- // Assume the following code
- struct NestedFragment1
- {
- float2 p3;
- };
- struct Fragment1
- {
- float4 p1 : SV_TARGET0;
- float3 p2 : SV_TARGET1;
- NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct
- };
-
- */
-
- // Metal does allow semantics on members of a nested struct but we are avoiding this
- // approach since there are senarios where legalization (and verification) is
- // hard/expensive without creating a flat struct:
- // 1. Entry points may share structs, semantics may be inconsistent across entry points
- // 2. Multiple of the same struct may be used in a param list
- /*
- // Assume the following code
- struct NestedFragment
- {
- float2 p3;
- };
- struct Fragment
- {
- float4 p1 : SV_TARGET0;
- NestedFragment p2 : SV_TARGET1;
- NestedFragment p3 : SV_TARGET2;
- };
-
- // Legalized without flattening -- abandoned
- struct NestedFragment1
- {
- float2 p3 : SV_TARGET1;
- };
- struct NestedFragment2
- {
- float2 p3 : SV_TARGET2;
- };
- struct Fragment
- {
- float4 p1 : SV_TARGET0;
- NestedFragment1 p2;
- NestedFragment2 p3;
- };
-
- // Legalized with flattening -- current approach
- struct Fragment
- {
- float4 p1 : SV_TARGET0;
- float2 p2 : SV_TARGET1;
- float2 p3 : SV_TARGET2;
- };
- */
-
- auto func = entryPoint.entryPointFunc;
- bool modified = false;
- for (auto param : func->getParams())
- {
- auto layout = findVarLayout(param);
- if (!layout)
- continue;
- if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
- continue;
- if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- continue;
- // If we find a IRParam with a IRStructType member, we need to flatten the entire
- // IRParam
- if (auto structType = as<IRStructType>(param->getDataType()))
- {
- IRBuilder builder(func);
- MapStructToFlatStruct mapOldFieldToNewField;
-
- // Flatten struct if we have nested IRStructType
- auto flattenedStruct = maybeFlattenNestedStructs(
- builder,
- structType,
- mapOldFieldToNewField,
- semanticInfoToRemove);
- if (flattenedStruct != structType)
- {
- // Validate/rearange all semantics which overlap in our flat struct
- fixFieldSemanticsOfFlatStruct(flattenedStruct);
-
- // Replace the 'old IRParam type' with a 'new IRParam type'
- param->setFullType(flattenedStruct);
-
- // Emit a new variable at EntryPoint of 'old IRParam type'
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto dstVal = builder.emitVar(structType);
- auto dstLoad = builder.emitLoad(dstVal);
- param->replaceUsesWith(dstLoad);
- builder.setInsertBefore(dstLoad);
- // Copy the 'new IRParam type' to our 'old IRParam type'
- mapOldFieldToNewField
- .emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>(
- builder,
- dstVal,
- param);
-
- modified = true;
- }
- }
- }
- if (modified)
- fixUpFuncType(func);
- }
-
- void packStageInParameters(EntryPointInfo entryPoint)
- {
- // If the entry point has any parameters whose layout contains VaryingInput,
- // we need to pack those parameters into a single `struct` type, and decorate
- // the fields with the appropriate `[[attribute]]` decorations.
- // For other parameters that are not `VaryingInput`, we need to leave them as is.
- //
- // For example, given this code after `hoistEntryPointParameterFromStruct`:
- // ```
- // void main(float3 pos, float2 uv, int vertexId : SV_VertexID) {
- // VertexInput vin = {pos,uv,vertexId};
- // ...
- // }
- // ```
- // We are going to transform it into:
- // ```
- // struct VertexInput {
- // float3 pos [[attribute(0)]];
- // float2 uv [[attribute(1)]];
- // };
- // void main(VertexInput vin, int vertexId : SV_VertexID) {
- // let pos = vin.pos;
- // let uv = vin.uv;
- // ...
- // }
-
- auto func = entryPoint.entryPointFunc;
-
- bool isGeometryStage = false;
- switch (entryPoint.entryPointDecor->getProfile().getStage())
- {
- case Stage::Vertex:
- case Stage::Amplification:
- case Stage::Mesh:
- case Stage::Geometry:
- case Stage::Domain:
- case Stage::Hull:
- isGeometryStage = true;
- break;
- }
-
- List<IRParam*> paramsToPack;
- for (auto param : func->getParams())
- {
- auto layout = findVarLayout(param);
- if (!layout)
- continue;
- if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
- continue;
- if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- continue;
- paramsToPack.add(param);
- }
-
- if (paramsToPack.getCount() == 0)
- return;
-
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- IRStructType* structType = builder.createStructType();
- auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
- builder.addNameHintDecoration(
- structType,
- (String(stageText) + toSlice("Input")).getUnownedSlice());
- List<IRStructKey*> keys;
- IRStructTypeLayout::Builder layoutBuilder(&builder);
- for (auto param : paramsToPack)
- {
- auto paramVarLayout = findVarLayout(param);
- auto key = builder.createStructKey();
- param->transferDecorationsTo(key);
- builder.createStructField(structType, key, param->getDataType());
- if (auto varyingInOffsetAttr =
- paramVarLayout->findOffsetAttr(LayoutResourceKind::VaryingInput))
- {
- if (!key->findDecoration<IRSemanticDecoration>() &&
- !paramVarLayout->findAttr<IRSemanticAttr>())
- {
- // If the parameter doesn't have a semantic, we need to add one for semantic
- // matching.
- builder.addSemanticDecoration(
- key,
- toSlice("_slang_attr"),
- (int)varyingInOffsetAttr->getOffset());
- }
- }
- if (isGeometryStage)
- {
- // For geometric stages, we need to translate VaryingInput offsets to MetalAttribute
- // offsets.
- IRVarLayout::Builder elementVarLayoutBuilder(
- &builder,
- paramVarLayout->getTypeLayout());
- elementVarLayoutBuilder.cloneEverythingButOffsetsFrom(paramVarLayout);
- for (auto offsetAttr : paramVarLayout->getOffsetAttrs())
- {
- auto resourceKind = offsetAttr->getResourceKind();
- if (resourceKind == LayoutResourceKind::VaryingInput)
- {
- resourceKind = LayoutResourceKind::MetalAttribute;
- }
- auto resInfo = elementVarLayoutBuilder.findOrAddResourceInfo(resourceKind);
- resInfo->offset = offsetAttr->getOffset();
- resInfo->space = offsetAttr->getSpace();
- }
- paramVarLayout = elementVarLayoutBuilder.build();
- }
- layoutBuilder.addField(key, paramVarLayout);
- builder.addLayoutDecoration(key, paramVarLayout);
- keys.add(key);
- }
- builder.setInsertInto(func->getFirstBlock());
- auto packedParam = builder.emitParamAtHead(structType);
- auto typeLayout = layoutBuilder.build();
- IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
-
- // Add a VaryingInput resource info to the packed parameter layout, so that we can emit
- // the needed `[[stage_in]]` attribute in Metal emitter.
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(packedParam, paramVarLayout);
-
- // Replace the original parameters with the packed parameter
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- for (Index paramIndex = 0; paramIndex < paramsToPack.getCount(); paramIndex++)
- {
- auto param = paramsToPack[paramIndex];
- auto key = keys[paramIndex];
- auto paramField = builder.emitFieldExtract(param->getDataType(), packedParam, key);
- param->replaceUsesWith(paramField);
- param->removeFromParent();
- }
- fixUpFuncType(func);
- }
-
- struct MetalSystemValueInfo
- {
- String metalSystemValueName;
- SystemValueSemanticName metalSystemValueNameEnum;
- ShortList<IRType*> permittedTypes;
- bool isUnsupported = false;
- bool isSpecial = false;
- MetalSystemValueInfo()
- {
- // most commonly need 2
- permittedTypes.reserveOverflowBuffer(2);
- }
- };
-
- IRType* getGroupThreadIdType(IRBuilder& builder)
- {
- return builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3));
- }
-
- // Get all permitted types of "sv_target" for Metal
- ShortList<IRType*>& getPermittedTypes_sv_target(IRBuilder& builder)
- {
- permittedTypes_sv_target.reserveOverflowBuffer(5 * 4);
- if (permittedTypes_sv_target.getCount() == 0)
- {
- for (auto baseType :
- {BaseType::Float,
- BaseType::Half,
- BaseType::Int,
- BaseType::UInt,
- BaseType::Int16,
- BaseType::UInt16})
- {
- for (IRIntegerValue i = 1; i <= 4; i++)
- {
- permittedTypes_sv_target.add(
- builder.getVectorType(builder.getBasicType(baseType), i));
- }
- }
- }
- return permittedTypes_sv_target;
- }
-
- MetalSystemValueInfo getSystemValueInfo(
- String inSemanticName,
- String* optionalSemanticIndex,
- IRInst* parentVar)
- {
- IRBuilder builder(m_module);
- MetalSystemValueInfo result = {};
- UnownedStringSlice semanticName;
- UnownedStringSlice semanticIndex;
-
- auto hasExplicitIndex =
- splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
- if (!hasExplicitIndex && optionalSemanticIndex)
- semanticIndex = optionalSemanticIndex->getUnownedSlice();
-
- result.metalSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName);
-
- switch (result.metalSystemValueNameEnum)
- {
- case SystemValueSemanticName::Position:
- {
- result.metalSystemValueName = toSlice("position");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::Float),
- builder.getIntValue(builder.getIntType(), 4)));
- break;
- }
- case SystemValueSemanticName::ClipDistance:
- {
- result.isSpecial = true;
- break;
- }
- case SystemValueSemanticName::CullDistance:
- {
- result.isSpecial = true;
- break;
- }
- case SystemValueSemanticName::Coverage:
- {
- result.metalSystemValueName = toSlice("sample_mask");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- break;
- }
- case SystemValueSemanticName::InnerCoverage:
- {
- result.isSpecial = true;
- break;
- }
- case SystemValueSemanticName::Depth:
- {
- result.metalSystemValueName = toSlice("depth(any)");
- result.permittedTypes.add(builder.getBasicType(BaseType::Float));
- break;
- }
- case SystemValueSemanticName::DepthGreaterEqual:
- {
- result.metalSystemValueName = toSlice("depth(greater)");
- result.permittedTypes.add(builder.getBasicType(BaseType::Float));
- break;
- }
- case SystemValueSemanticName::DepthLessEqual:
- {
- result.metalSystemValueName = toSlice("depth(less)");
- result.permittedTypes.add(builder.getBasicType(BaseType::Float));
- break;
- }
- case SystemValueSemanticName::DispatchThreadID:
- {
- result.metalSystemValueName = toSlice("thread_position_in_grid");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3)));
- break;
- }
- case SystemValueSemanticName::DomainLocation:
- {
- result.metalSystemValueName = toSlice("position_in_patch");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::Float),
- builder.getIntValue(builder.getIntType(), 3)));
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::Float),
- builder.getIntValue(builder.getIntType(), 2)));
- break;
- }
- case SystemValueSemanticName::GroupID:
- {
- result.metalSystemValueName = toSlice("threadgroup_position_in_grid");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3)));
- break;
- }
- case SystemValueSemanticName::GroupIndex:
- {
- result.isSpecial = true;
- break;
- }
- case SystemValueSemanticName::GroupThreadID:
- {
- result.metalSystemValueName = toSlice("thread_position_in_threadgroup");
- result.permittedTypes.add(getGroupThreadIdType(builder));
- break;
- }
- case SystemValueSemanticName::GSInstanceID:
- {
- result.isUnsupported = true;
- break;
- }
- case SystemValueSemanticName::InstanceID:
- {
- result.metalSystemValueName = toSlice("instance_id");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- break;
- }
- case SystemValueSemanticName::IsFrontFace:
- {
- result.metalSystemValueName = toSlice("front_facing");
- result.permittedTypes.add(builder.getBasicType(BaseType::Bool));
- break;
- }
- case SystemValueSemanticName::OutputControlPointID:
- {
- // In metal, a hull shader is just a compute shader.
- // This needs to be handled separately, by lowering into an ordinary buffer.
- break;
- }
- case SystemValueSemanticName::PointSize:
- {
- result.metalSystemValueName = toSlice("point_size");
- result.permittedTypes.add(builder.getBasicType(BaseType::Float));
- break;
- }
- case SystemValueSemanticName::PrimitiveID:
- {
- result.metalSystemValueName = toSlice("primitive_id");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
- break;
- }
- case SystemValueSemanticName::RenderTargetArrayIndex:
- {
- result.metalSystemValueName = toSlice("render_target_array_index");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
- break;
- }
- case SystemValueSemanticName::SampleIndex:
- {
- result.metalSystemValueName = toSlice("sample_id");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- break;
- }
- case SystemValueSemanticName::StencilRef:
- {
- result.metalSystemValueName = toSlice("stencil");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- break;
- }
- case SystemValueSemanticName::TessFactor:
- {
- // Tessellation factor outputs should be lowered into a write into a normal buffer.
- break;
- }
- case SystemValueSemanticName::VertexID:
- {
- result.metalSystemValueName = toSlice("vertex_id");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- break;
- }
- case SystemValueSemanticName::ViewID:
- {
- result.isUnsupported = true;
- break;
- }
- case SystemValueSemanticName::ViewportArrayIndex:
- {
- result.metalSystemValueName = toSlice("viewport_array_index");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt16));
- break;
- }
- case SystemValueSemanticName::Target:
- {
- result.metalSystemValueName =
- (StringBuilder()
- << "color(" << (semanticIndex.getLength() != 0 ? semanticIndex : toSlice("0"))
- << ")")
- .produceString();
- result.permittedTypes = getPermittedTypes_sv_target(builder);
-
- break;
- }
- case SystemValueSemanticName::StartVertexLocation:
- {
- result.metalSystemValueName = toSlice("base_vertex");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- break;
- }
- case SystemValueSemanticName::StartInstanceLocation:
- {
- result.metalSystemValueName = toSlice("base_instance");
- result.permittedTypes.add(builder.getBasicType(BaseType::UInt));
- break;
- }
- default:
- m_sink->diagnose(
- parentVar,
- Diagnostics::unimplementedSystemValueSemantic,
- semanticName);
- return result;
- }
- return result;
- }
-
- void reportUnsupportedSystemAttribute(IRInst* param, String semanticName)
- {
- m_sink->diagnose(
- param->sourceLoc,
- Diagnostics::systemValueAttributeNotSupported,
- semanticName);
- }
-
- void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout)
- {
- // Ensure each field in an output struct type has either a system semantic or a user
- // semantic, so that signature matching can happen correctly.
- auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
- Index index = 0;
- IRBuilder builder(structType);
- for (auto field : structType->getFields())
- {
- auto key = field->getKey();
- if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
- {
- if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
- {
- auto indexAsString = String(UInt(semanticDecor->getSemanticIndex()));
- auto sysValInfo =
- getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field);
- if (sysValInfo.isUnsupported || sysValInfo.isSpecial)
- {
- reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName());
- }
- else
- {
- builder.addTargetSystemValueDecoration(
- key,
- sysValInfo.metalSystemValueName.getUnownedSlice());
- semanticDecor->removeAndDeallocate();
- }
- }
- index++;
- continue;
- }
- typeLayout->getFieldLayout(index);
- auto fieldLayout = typeLayout->getFieldLayout(index);
- if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
- {
- UInt varOffset = 0;
- if (auto varOffsetAttr =
- varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
- varOffset = varOffsetAttr->getOffset();
- varOffset += offsetAttr->getOffset();
- builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset);
- }
- index++;
- }
- }
-
- // Stores a hicharchy of members and children which map 'oldStruct->member' to
- // 'flatStruct->member' Note: this map assumes we map to FlatStruct since it is easier/faster to
- // process
- struct MapStructToFlatStruct
- {
- /*
- We need a hicharchy map to resolve dependencies for mapping
- oldStruct to newStruct efficently. Example:
-
- MyStruct
- |
- / | \
- / | \
- / | \
- M0<A> M1<A> M2<B>
- | | |
- A_0 A_0 B_0
-
- Without storing hicharchy information, there will be no way to tell apart
- `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField
- only has 1 instance of `A::A0`
- */
-
- enum CopyOptions : int
- {
- // Copy a flattened-struct into a struct
- FlatStructIntoStruct = 0,
-
- // Copy a struct into a flattened-struct
- StructIntoFlatStruct = 1,
- };
-
- private:
- // Children of member if applicable.
- Dictionary<IRStructField*, MapStructToFlatStruct> members;
-
- // Field correlating to MapStructToFlatStruct Node.
- IRInst* node;
- IRStructKey* getKey()
- {
- SLANG_ASSERT(as<IRStructField>(node));
- return as<IRStructField>(node)->getKey();
- }
- IRInst* getNode() { return node; }
- IRType* getFieldType()
- {
- SLANG_ASSERT(as<IRStructField>(node));
- return as<IRStructField>(node)->getFieldType();
- }
-
- // Whom node maps to inside target flatStruct
- IRStructField* targetMapping;
-
- auto begin() { return members.begin(); }
- auto end() { return members.end(); }
-
- // Copies members of oldStruct to/from newFlatStruct. Assumes members of val1 maps to
- // members in val2 using `MapStructToFlatStruct`
- template<int copyOptions>
- static void _emitCopy(
- IRBuilder& builder,
- IRInst* val1,
- IRStructType* type1,
- IRInst* val2,
- IRStructType* type2,
- MapStructToFlatStruct& node)
- {
- for (auto& field1Pair : node)
- {
- auto& field1 = field1Pair.second;
-
- // Get member of val1
- IRInst* fieldAddr1 = nullptr;
- if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
- {
- fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey());
- }
- else
- {
- if (as<IRPtrTypeBase>(val1))
- val1 = builder.emitLoad(val1);
- fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey());
- }
-
- // If val1 is a struct, recurse
- if (auto fieldAsStruct1 = as<IRStructType>(field1.getFieldType()))
- {
- _emitCopy<copyOptions>(
- builder,
- fieldAddr1,
- fieldAsStruct1,
- val2,
- type2,
- field1);
- continue;
- }
-
- // Get member of val2 which maps to val1.member
- auto field2 = field1.getMapping();
- SLANG_ASSERT(field2);
- IRInst* fieldAddr2 = nullptr;
- if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
- {
- if (as<IRPtrTypeBase>(val2))
- val2 = builder.emitLoad(val1);
- fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey());
- }
- else
- {
- fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey());
- }
-
- // Copy val2/val1 member into val1/val2 member
- if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
- {
- builder.emitStore(fieldAddr1, fieldAddr2);
- }
- else
- {
- builder.emitStore(fieldAddr2, fieldAddr1);
- }
- }
- }
-
- public:
- void setNode(IRInst* newNode) { node = newNode; }
- // Get 'MapStructToFlatStruct' that is a child of 'parent'.
- // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'.
- MapStructToFlatStruct& getMember(IRStructField* member) { return members[member]; }
- MapStructToFlatStruct& operator[](IRStructField* member) { return getMember(member); }
-
- void setMapping(IRStructField* newTargetMapping) { targetMapping = newTargetMapping; }
- // Get 'MapStructToFlatStruct' that is a child of 'parent'.
- // Return nullptr if no member is mapped to 'parent'
- IRStructField* getMapping() { return targetMapping; }
-
- // Copies srcVal into dstVal using hicharchy map.
- template<int copyOptions>
- void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal)
- {
- auto dstType = dstVal->getDataType();
- if (auto dstPtrType = as<IRPtrTypeBase>(dstType))
- dstType = dstPtrType->getValueType();
- auto dstStructType = as<IRStructType>(dstType);
- SLANG_ASSERT(dstStructType);
-
- auto srcType = srcVal->getDataType();
- if (auto srcPtrType = as<IRPtrTypeBase>(srcType))
- srcType = srcPtrType->getValueType();
- auto srcStructType = as<IRStructType>(srcType);
- SLANG_ASSERT(srcStructType);
-
- if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
- {
- // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a
- // struct
- SLANG_ASSERT(node == dstStructType);
- _emitCopy<copyOptions>(
- builder,
- dstVal,
- dstStructType,
- srcVal,
- srcStructType,
- *this);
- }
- else
- {
- // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct
- SLANG_ASSERT(node == srcStructType);
- _emitCopy<copyOptions>(
- builder,
- srcVal,
- srcStructType,
- dstVal,
- dstStructType,
- *this);
- }
- }
- };
-
- IRStructType* _flattenNestedStructs(
- IRBuilder& builder,
- IRStructType* dst,
- IRStructType* src,
- IRSemanticDecoration* parentSemanticDecoration,
- IRLayoutDecoration* parentLayout,
- MapStructToFlatStruct& mapFieldToField,
- HashSet<IRStructField*>& varsWithSemanticInfo)
- {
- // For all fields ('oldField') of a struct do the following:
- // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration,
- // IRLayoutDecoration), store these if found.
- // * Do not propagate semantic info if the current node has *any* form of semantic
- // information.
- // Update varsWithSemanticInfo.
- // 2. If IRStructType:
- // 2a. Recurse this function with 'decorations that carry semantic info' from parent.
- // 3. If not IRStructType:
- // 3a. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic info'.
- // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is
- // needed to copy between types.
- for (auto oldField : src->getFields())
- {
- auto& fieldMappingNode = mapFieldToField[oldField];
- fieldMappingNode.setNode(oldField);
-
- // step 1
- bool foundSemanticDecor = false;
- auto oldKey = oldField->getKey();
- IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration;
- if (auto oldSemanticDecoration = oldKey->findDecoration<IRSemanticDecoration>())
- {
- foundSemanticDecor = true;
- fieldSemanticDecoration = oldSemanticDecoration;
- parentLayout = nullptr;
- }
-
- IRLayoutDecoration* fieldLayout = parentLayout;
- if (auto oldLayout = oldKey->findDecoration<IRLayoutDecoration>())
- {
- fieldLayout = oldLayout;
- if (!foundSemanticDecor)
- fieldSemanticDecoration = nullptr;
- }
- if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout)
- varsWithSemanticInfo.add(oldField);
-
- // step 2a
- if (auto structFieldType = as<IRStructType>(oldField->getFieldType()))
- {
- _flattenNestedStructs(
- builder,
- dst,
- structFieldType,
- fieldSemanticDecoration,
- fieldLayout,
- fieldMappingNode,
- varsWithSemanticInfo);
- continue;
- }
-
- // step 3a
- auto newKey = builder.createStructKey();
- copyNameHintAndDebugDecorations(newKey, oldKey);
-
- auto newField = builder.createStructField(dst, newKey, oldField->getFieldType());
- copyNameHintAndDebugDecorations(newField, oldField);
-
- if (fieldSemanticDecoration)
- builder.addSemanticDecoration(
- newKey,
- fieldSemanticDecoration->getSemanticName(),
- fieldSemanticDecoration->getSemanticIndex());
-
- if (fieldLayout)
- {
- IRLayout* oldLayout = fieldLayout->getLayout();
- List<IRInst*> instToCopy;
- // Only copy certain decorations needed for resolving system semantics
- for (UInt i = 0; i < oldLayout->getOperandCount(); i++)
- {
- auto operand = oldLayout->getOperand(i);
- if (as<IRVarOffsetAttr>(operand) || as<IRUserSemanticAttr>(operand) ||
- as<IRSystemValueSemanticAttr>(operand) || as<IRStageAttr>(operand))
- instToCopy.add(operand);
- }
- IRVarLayout* newLayout = builder.getVarLayout(instToCopy);
- builder.addLayoutDecoration(newKey, newLayout);
- }
- // step 3b
- fieldMappingNode.setMapping(newField);
- }
-
- return dst;
- }
-
- // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there
- // was no struct flattening.
- // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct
- // `IRStructFields`s
- IRStructType* maybeFlattenNestedStructs(
- IRBuilder& builder,
- IRStructType* src,
- MapStructToFlatStruct& mapFieldToField,
- HashSet<IRStructField*>& varsWithSemanticInfo)
- {
- // Find all values inside struct that need flattening and legalization.
- bool hasStructTypeMembers = false;
- for (auto field : src->getFields())
- {
- if (as<IRStructType>(field->getFieldType()))
- {
- hasStructTypeMembers = true;
- break;
- }
- }
- if (!hasStructTypeMembers)
- return src;
-
- // We need to:
- // 1. Make new struct 1:1 with old struct but without nestested structs (flatten)
- // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be
- // handled later).
- // 3. Store the mapping from old to new struct fields to allow copying a old-struct to
- // new-struct.
- builder.setInsertAfter(src);
- auto newStruct = builder.createStructType();
- copyNameHintAndDebugDecorations(newStruct, src);
- mapFieldToField.setNode(src);
- return _flattenNestedStructs(
- builder,
- newStruct,
- src,
- nullptr,
- nullptr,
- mapFieldToField,
- varsWithSemanticInfo);
- }
-
- // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'.
- // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function.
- template<typename CopyLogicFunc>
- void _replaceAllReturnInst(
- IRBuilder& builder,
- IRFunc* targetFunc,
- IRStructType* newType,
- CopyLogicFunc copyLogicFunc)
- {
- for (auto block : targetFunc->getBlocks())
- {
- if (auto returnInst = as<IRReturn>(block->getTerminator()))
- {
- builder.setInsertBefore(returnInst);
- auto returnVal = returnInst->getVal();
- returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal));
- }
- }
- }
-
- UInt _returnNonOverlappingAttributeIndex(std::set<UInt>& usedSemanticIndex)
- {
- // Find first unused semantic index of equal semantic type
- // to fill any gaps in user set semantic bindings
- UInt prev = 0;
- for (auto i : usedSemanticIndex)
- {
- if (i > prev + 1)
- {
- break;
- }
- prev = i;
- }
- usedSemanticIndex.insert(prev + 1);
- return prev + 1;
- }
-
- template<typename T>
- struct AttributeParentPair
- {
- IRLayoutDecoration* layoutDecor;
- T* attr;
- };
-
- IRLayoutDecoration* _replaceAttributeOfLayout(
- IRBuilder& builder,
- IRLayoutDecoration* parentLayoutDecor,
- IRInst* instToReplace,
- IRInst* instToReplaceWith)
- {
- // Replace `instToReplace` with a `instToReplaceWith`
-
- auto layout = parentLayoutDecor->getLayout();
- // Find the exact same decoration `instToReplace` in-case multiple of the same type exist
- List<IRInst*> opList;
- opList.add(instToReplaceWith);
- for (UInt i = 0; i < layout->getOperandCount(); i++)
- {
- if (layout->getOperand(i) != instToReplace)
- opList.add(layout->getOperand(i));
- }
- auto newLayoutDecor = builder.addLayoutDecoration(
- parentLayoutDecor->getParent(),
- builder.getVarLayout(opList));
- parentLayoutDecor->removeAndDeallocate();
- return newLayoutDecor;
- }
-
- IRLayoutDecoration* _simplifyUserSemanticNames(
- IRBuilder& builder,
- IRLayoutDecoration* layoutDecor)
- {
- // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into
- // ("SV_TARGET", 0) using 'IRUserSemanticAttr' This is done to ensure we can check semantic
- // groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()'
- SLANG_ASSERT(layoutDecor);
- auto layout = layoutDecor->getLayout();
- List<IRInst*> layoutOps;
- layoutOps.reserve(3);
- bool changed = false;
- for (auto attr : layout->getAllAttrs())
- {
- if (auto userSemantic = as<IRUserSemanticAttr>(attr))
- {
- UnownedStringSlice outName;
- UnownedStringSlice outIndex;
- bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex);
- if (hasStringIndex)
- {
- changed = true;
- auto loweredName = String(outName).toLower();
- auto loweredNameSlice = loweredName.getUnownedSlice();
- auto newDecoration =
- builder.getUserSemanticAttr(loweredNameSlice, stringToInt(outIndex));
- userSemantic->replaceUsesWith(newDecoration);
- userSemantic->removeAndDeallocate();
- userSemantic = newDecoration;
- }
- layoutOps.add(userSemantic);
- continue;
- }
- layoutOps.add(attr);
- }
- if (changed)
- {
- auto parent = layoutDecor->parent;
- layoutDecor->removeAndDeallocate();
- builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps));
- }
- return layoutDecor;
- }
- // Find overlapping field semantics and legalize them
- void fixFieldSemanticsOfFlatStruct(IRStructType* structType)
- {
- // Goal is to ensure we do not have overlapping semantics:
- /*
- // Assume the following code
- struct Fragment
- {
- float4 p1 : SV_TARGET;
- float3 p2 : SV_TARGET;
- float2 p3 : SV_TARGET;
- float2 p4 : SV_TARGET;
- };
-
- // Translates into
- struct Fragment
- {
- float4 p1 : SV_TARGET0;
- float3 p2 : SV_TARGET1;
- float2 p3 : SV_TARGET2;
- float2 p4 : SV_TARGET3;
- };
- */
-
- IRBuilder builder(this->m_module);
-
- List<IRSemanticDecoration*> overlappingSemanticsDecor;
- Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>>
- usedSemanticIndexSemanticDecor;
-
- List<AttributeParentPair<IRVarOffsetAttr>> overlappingVarOffset;
- Dictionary<UInt, std::set<UInt, std::less<UInt>>> usedSemanticIndexVarOffset;
-
- List<AttributeParentPair<IRUserSemanticAttr>> overlappingUserSemantic;
- Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>>
- usedSemanticIndexUserSemantic;
-
- // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when
- // legalizing we may destroy and remake a `IRLayoutDecoration*`
- Dictionary<IRLayoutDecoration*, IRLayoutDecoration*> oldLayoutDecorToNew;
-
- // Collect all "semantic info carrying decorations". Any collected decoration will
- // fill up their respective 'Dictionary<SEMANTIC_TYPE, OrderedHashSet<UInt>>'
- // to keep track of in-use offsets for a semantic type.
- // Example: IRSemanticDecoration with name of "SV_TARGET1".
- // * This will have SEMANTIC_TYPE of "sv_target".
- // * This will use up index '1'
- //
- // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to
- // a list of 'overlapping semantic info decorations' so we can legalize this
- // 'semantic info decoration' later.
- //
- // NOTE: this is a flat struct, all members are children of the initial
- // IRStructType.
- for (auto field : structType->getFields())
- {
- auto key = field->getKey();
- if (auto semanticDecoration = key->findDecoration<IRSemanticDecoration>())
- {
- // Ensure names are in a uniform lowercase format so we can bunch together simmilar
- // semantics
- UnownedStringSlice outName;
- UnownedStringSlice outIndex;
- bool hasStringIndex =
- splitNameAndIndex(semanticDecoration->getSemanticName(), outName, outIndex);
- if (hasStringIndex)
- {
- auto loweredName = String(outName).toLower();
- auto loweredNameSlice = loweredName.getUnownedSlice();
- auto newDecoration =
- builder.addSemanticDecoration(key, loweredNameSlice, stringToInt(outIndex));
- semanticDecoration->replaceUsesWith(newDecoration);
- semanticDecoration->removeAndDeallocate();
- semanticDecoration = newDecoration;
- }
- auto& semanticUse =
- usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()];
- if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end())
- overlappingSemanticsDecor.add(semanticDecoration);
- else
- semanticUse.insert(semanticDecoration->getSemanticIndex());
- }
- if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
- {
- // Ensure names are in a uniform lowercase format so we can bunch together simmilar
- // semantics
- layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor);
- oldLayoutDecorToNew[layoutDecor] = layoutDecor;
- auto layout = layoutDecor->getLayout();
- for (auto attr : layout->getAllAttrs())
- {
- if (auto offset = as<IRVarOffsetAttr>(attr))
- {
- auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()];
- if (semanticUse.find(offset->getOffset()) != semanticUse.end())
- overlappingVarOffset.add({layoutDecor, offset});
- else
- semanticUse.insert(offset->getOffset());
- }
- else if (auto userSemantic = as<IRUserSemanticAttr>(attr))
- {
- auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()];
- if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end())
- overlappingUserSemantic.add({layoutDecor, userSemantic});
- else
- semanticUse.insert(userSemantic->getIndex());
- }
- }
- }
- }
-
- // Legalize all overlapping 'semantic info decorations'
- for (auto decor : overlappingSemanticsDecor)
- {
- auto newOffset = _returnNonOverlappingAttributeIndex(
- usedSemanticIndexSemanticDecor[decor->getSemanticName()]);
- builder.addSemanticDecoration(
- decor->getParent(),
- decor->getSemanticName(),
- (int)newOffset);
- decor->removeAndDeallocate();
- }
- for (auto& varOffset : overlappingVarOffset)
- {
- auto newOffset = _returnNonOverlappingAttributeIndex(
- usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]);
- auto newVarOffset = builder.getVarOffsetAttr(
- varOffset.attr->getResourceKind(),
- newOffset,
- varOffset.attr->getSpace());
- oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout(
- builder,
- oldLayoutDecorToNew[varOffset.layoutDecor],
- varOffset.attr,
- newVarOffset);
- }
- for (auto& userSemantic : overlappingUserSemantic)
- {
- auto newOffset = _returnNonOverlappingAttributeIndex(
- usedSemanticIndexUserSemantic[userSemantic.attr->getName()]);
- auto newUserSemantic =
- builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset);
- oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout(
- builder,
- oldLayoutDecorToNew[userSemantic.layoutDecor],
- userSemantic.attr,
- newUserSemantic);
- }
- }
-
- void wrapReturnValueInStruct(EntryPointInfo entryPoint)
- {
- // Wrap return value into a struct if it is not already a struct.
- // For example, given this entry point:
- // ```
- // float4 main() : SV_Target { return float3(1,2,3); }
- // ```
- // We are going to transform it into:
- // ```
- // struct Output {
- // float4 value : SV_Target;
- // };
- // Output main() { return {float3(1,2,3)}; }
-
- auto func = entryPoint.entryPointFunc;
-
- auto returnType = func->getResultType();
- if (as<IRVoidType>(returnType))
- return;
- auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>();
- if (!entryPointLayoutDecor)
- return;
- auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout());
- if (!entryPointLayout)
- return;
- auto resultLayout = entryPointLayout->getResultLayout();
-
- // If return type is already a struct, just make sure every field has a semantic.
- if (auto returnStructType = as<IRStructType>(returnType))
- {
- IRBuilder builder(func);
- MapStructToFlatStruct mapOldFieldToNewField;
- // Flatten result struct type to ensure we do not have nested semantics
- auto flattenedStruct = maybeFlattenNestedStructs(
- builder,
- returnStructType,
- mapOldFieldToNewField,
- semanticInfoToRemove);
- if (returnStructType != flattenedStruct)
- {
- // Replace all return-values with the flattenedStruct we made.
- _replaceAllReturnInst(
- builder,
- func,
- flattenedStruct,
- [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
- {
- auto srcStructType = as<IRStructType>(srcVal->getDataType());
- SLANG_ASSERT(srcStructType);
- auto dstVal = copyBuilder.emitVar(dstType);
- mapOldFieldToNewField.emitCopy<(
- int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>(
- copyBuilder,
- dstVal,
- srcVal);
- return builder.emitLoad(dstVal);
- });
- fixUpFuncType(func, flattenedStruct);
- }
- // Ensure non-overlapping semantics
- fixFieldSemanticsOfFlatStruct(flattenedStruct);
- ensureResultStructHasUserSemantic(flattenedStruct, resultLayout);
- return;
- }
-
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- IRStructType* structType = builder.createStructType();
- auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
- builder.addNameHintDecoration(
- structType,
- (String(stageText) + toSlice("Output")).getUnownedSlice());
- auto key = builder.createStructKey();
- builder.addNameHintDecoration(key, toSlice("output"));
- builder.addLayoutDecoration(key, resultLayout);
- builder.createStructField(structType, key, returnType);
- IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder);
- structTypeLayoutBuilder.addField(key, resultLayout);
- auto typeLayout = structTypeLayoutBuilder.build();
- IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
- auto varLayout = varLayoutBuilder.build();
- ensureResultStructHasUserSemantic(structType, varLayout);
-
- _replaceAllReturnInst(
- builder,
- func,
- structType,
- [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
- { return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); });
-
- // Assign an appropriate system value semantic for stage output
- auto stage = entryPoint.entryPointDecor->getProfile().getStage();
- switch (stage)
- {
- case Stage::Compute:
- case Stage::Fragment:
- {
- builder.addTargetSystemValueDecoration(key, toSlice("color(0)"));
- break;
- }
- case Stage::Vertex:
- {
- builder.addTargetSystemValueDecoration(key, toSlice("position"));
- break;
- }
- default:
- SLANG_ASSERT(false);
- return;
- }
-
- fixUpFuncType(func, structType);
- }
-
- void legalizeMeshEntryPoint(EntryPointInfo entryPoint)
- {
- auto func = entryPoint.entryPointFunc;
-
- IRBuilder builder{func->getModule()};
- for (auto param : func->getParams())
- {
- if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- {
- IRVarLayout::Builder varLayoutBuilder(
- &builder,
- IRTypeLayout::Builder{&builder}.build());
-
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(param, paramVarLayout);
-
- IRPtrTypeBase* type = as<IRPtrTypeBase>(param->getDataType());
-
- const auto annotatedPayloadType = builder.getPtrType(
- kIROp_ConstRefType,
- type->getValueType(),
- AddressSpace::MetalObjectData);
-
- param->setFullType(annotatedPayloadType);
- }
- }
- IROutputTopologyDecoration* outputDeco =
- entryPoint.entryPointFunc->findDecoration<IROutputTopologyDecoration>();
- if (outputDeco == nullptr)
- {
- SLANG_UNEXPECTED("Mesh shader output decoration missing");
- return;
- }
- const auto topology = outputDeco->getTopology();
- const auto topStr = topology->getStringSlice();
- UInt topologyEnum = 0;
- if (topStr.caseInsensitiveEquals(toSlice("point")))
- {
- topologyEnum = 1;
- }
- else if (topStr.caseInsensitiveEquals(toSlice("line")))
- {
- topologyEnum = 2;
- }
- else if (topStr.caseInsensitiveEquals(toSlice("triangle")))
- {
- topologyEnum = 3;
- }
- else
- {
- SLANG_UNEXPECTED("unknown topology");
- return;
- }
-
- IRInst* topologyConst = builder.getIntValue(builder.getIntType(), topologyEnum);
-
- IRType* vertexType = nullptr;
- IRType* indicesType = nullptr;
- IRType* primitiveType = nullptr;
-
- IRInst* maxVertices = nullptr;
- IRInst* maxPrimitives = nullptr;
-
- IRInst* verticesParam = nullptr;
- IRInst* indicesParam = nullptr;
- IRInst* primitivesParam = nullptr;
- for (auto param : func->getParams())
- {
- if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- {
- IRVarLayout::Builder varLayoutBuilder(
- &builder,
- IRTypeLayout::Builder{&builder}.build());
-
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(param, paramVarLayout);
- }
- if (param->findDecorationImpl(kIROp_VerticesDecoration))
- {
- auto vertexRefType = as<IRPtrTypeBase>(param->getDataType());
- auto vertexOutputType = as<IRVerticesType>(vertexRefType->getValueType());
- vertexType = vertexOutputType->getElementType();
- maxVertices = vertexOutputType->getMaxElementCount();
- SLANG_ASSERT(vertexType);
-
- verticesParam = param;
- auto vertStruct = as<IRStructType>(vertexType);
- for (auto field : vertStruct->getFields())
- {
- auto key = field->getKey();
- if (auto deco = key->findDecoration<IRSemanticDecoration>())
- {
- if (deco->getSemanticName().caseInsensitiveEquals(toSlice("sv_position")))
- {
- builder.addTargetSystemValueDecoration(key, toSlice("position"));
- }
- }
- }
- }
- if (param->findDecorationImpl(kIROp_IndicesDecoration))
- {
- auto indicesRefType = (IRConstRefType*)param->getDataType();
- auto indicesOutputType = (IRIndicesType*)indicesRefType->getValueType();
- indicesType = indicesOutputType->getElementType();
- maxPrimitives = indicesOutputType->getMaxElementCount();
- SLANG_ASSERT(indicesType);
-
- indicesParam = param;
- }
- if (param->findDecorationImpl(kIROp_PrimitivesDecoration))
- {
- auto primitivesRefType = (IRConstRefType*)param->getDataType();
- auto primitivesOutputType = (IRPrimitivesType*)primitivesRefType->getValueType();
- primitiveType = primitivesOutputType->getElementType();
- SLANG_ASSERT(primitiveType);
-
- primitivesParam = param;
- auto primStruct = as<IRStructType>(primitiveType);
- for (auto field : primStruct->getFields())
- {
- auto key = field->getKey();
- if (auto deco = key->findDecoration<IRSemanticDecoration>())
- {
- if (deco->getSemanticName().caseInsensitiveEquals(
- toSlice("sv_primitiveid")))
- {
- builder.addTargetSystemValueDecoration(key, toSlice("primitive_id"));
- }
- }
- }
- }
- }
- if (primitiveType == nullptr)
- {
- primitiveType = builder.getVoidType();
- }
- builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
-
- auto meshParam = builder.emitParam(builder.getMetalMeshType(
- vertexType,
- primitiveType,
- maxVertices,
- maxPrimitives,
- topologyConst));
- builder.addExternCppDecoration(meshParam, toSlice("_slang_mesh"));
-
-
- verticesParam->removeFromParent();
- verticesParam->removeAndDeallocate();
-
- indicesParam->removeFromParent();
- indicesParam->removeAndDeallocate();
-
- if (primitivesParam != nullptr)
- {
- primitivesParam->removeFromParent();
- primitivesParam->removeAndDeallocate();
- }
- }
-
- void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint)
- {
- // Find out DispatchMesh function
- IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
- for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
- {
- if (const auto func = as<IRGlobalValueWithCode>(globalInst))
- {
- if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
- {
- if (dec->getName() == "DispatchMesh")
- {
- SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
- dispatchMeshFunc = func;
- }
- }
- }
- }
-
- if (!dispatchMeshFunc)
- return;
-
- IRBuilder builder{entryPoint.entryPointFunc->getModule()};
-
- // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid
- traverseUses(
- dispatchMeshFunc,
- [&](const IRUse* use)
- {
- if (const auto call = as<IRCall>(use->getUser()))
- {
- SLANG_ASSERT(call->getArgCount() == 4);
- const auto payload = call->getArg(3);
-
- const auto payloadPtrType =
- composeGetters<IRPtrType>(payload, &IRInst::getDataType);
- SLANG_ASSERT(payloadPtrType);
- const auto payloadType = payloadPtrType->getValueType();
- SLANG_ASSERT(payloadType);
-
- builder.setInsertBefore(
- entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
- const auto annotatedPayloadType = builder.getPtrType(
- kIROp_RefType,
- payloadPtrType->getValueType(),
- AddressSpace::MetalObjectData);
- auto packedParam = builder.emitParam(annotatedPayloadType);
- builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload"));
- IRVarLayout::Builder varLayoutBuilder(
- &builder,
- IRTypeLayout::Builder{&builder}.build());
-
- // Add the MetalPayload resource info, so we can emit [[payload]]
- varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
- auto paramVarLayout = varLayoutBuilder.build();
- builder.addLayoutDecoration(packedParam, paramVarLayout);
-
- // Now we replace the call to DispatchMesh with a call to the mesh grid
- // properties But first we need to create the parameter
- const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
- auto mgp = builder.emitParam(meshGridPropertiesType);
- builder.addExternCppDecoration(mgp, toSlice("_slang_mgp"));
- }
- });
- }
-
- IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
- {
- auto fromType = val->getFullType();
- if (auto fromVector = as<IRVectorType>(fromType))
- {
- if (auto toVector = as<IRVectorType>(toType))
- {
- if (fromVector->getElementCount() != toVector->getElementCount())
- {
- fromType = builder.getVectorType(
- fromVector->getElementType(),
- toVector->getElementCount());
- val = builder.emitVectorReshape(fromType, val);
- }
- }
- else if (as<IRBasicType>(toType))
- {
- UInt index = 0;
- val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
- if (toType->getOp() == kIROp_VoidType)
- return nullptr;
- }
- }
- else if (auto fromBasicType = as<IRBasicType>(fromType))
- {
- if (fromBasicType->getOp() == kIROp_VoidType)
- return nullptr;
- if (!as<IRBasicType>(toType))
- return nullptr;
- if (toType->getOp() == kIROp_VoidType)
- return nullptr;
- }
- else
- {
- return nullptr;
- }
- return builder.emitCast(toType, val);
- }
-
- struct SystemValLegalizationWorkItem
- {
- IRInst* var;
- String attrName;
- UInt attrIndex;
- };
-
- std::optional<SystemValLegalizationWorkItem> tryToMakeSystemValWorkItem(IRInst* var)
- {
- if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>())
- {
- if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
- {
- return {
- {var,
- String(semanticDecoration->getSemanticName()).toLower(),
- (UInt)semanticDecoration->getSemanticIndex()}};
- }
- }
-
- auto layoutDecor = var->findDecoration<IRLayoutDecoration>();
- if (!layoutDecor)
- return {};
- auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>();
- if (!sysValAttr)
- return {};
- auto semanticName = String(sysValAttr->getName());
- auto sysAttrIndex = sysValAttr->getIndex();
-
- return {{var, semanticName, sysAttrIndex}};
- }
-
-
- List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint)
- {
- List<SystemValLegalizationWorkItem> systemValWorkItems;
- for (auto param : entryPoint.entryPointFunc->getParams())
- {
- auto maybeWorkItem = tryToMakeSystemValWorkItem(param);
- if (maybeWorkItem.has_value())
- systemValWorkItems.add(std::move(maybeWorkItem.value()));
- }
- return systemValWorkItems;
- }
-
- void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem)
- {
- IRBuilder builder(entryPoint.entryPointFunc);
-
- auto var = workItem.var;
- auto semanticName = workItem.attrName;
-
- auto indexAsString = String(workItem.attrIndex);
- auto info = getSystemValueInfo(semanticName, &indexAsString, var);
-
- if (info.isSpecial)
- {
- if (info.metalSystemValueNameEnum == SystemValueSemanticName::InnerCoverage)
- {
- // Metal does not support conservative rasterization, so this is always false.
- auto val = builder.getBoolValue(false);
- var->replaceUsesWith(val);
- var->removeAndDeallocate();
- }
- else if (info.metalSystemValueNameEnum == SystemValueSemanticName::GroupIndex)
- {
- // Ensure we have a cached "sv_groupthreadid" in our entry point
- if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc))
- {
- auto systemValWorkItems = collectSystemValFromEntryPoint(entryPoint);
- for (auto i : systemValWorkItems)
- {
- auto indexAsStringGroupThreadId = String(i.attrIndex);
- if (getSystemValueInfo(i.attrName, &indexAsStringGroupThreadId, i.var)
- .metalSystemValueNameEnum == SystemValueSemanticName::GroupThreadID)
- {
- entryPointToGroupThreadId[entryPoint.entryPointFunc] = i.var;
- }
- }
- if (!entryPointToGroupThreadId.containsKey(entryPoint.entryPointFunc))
- {
- // Add the missing groupthreadid needed to compute sv_groupindex
- IRBuilder groupThreadIdBuilder(builder);
- groupThreadIdBuilder.setInsertInto(
- entryPoint.entryPointFunc->getFirstBlock());
- auto groupThreadId = groupThreadIdBuilder.emitParamAtHead(
- getGroupThreadIdType(groupThreadIdBuilder));
- entryPointToGroupThreadId[entryPoint.entryPointFunc] = groupThreadId;
- groupThreadIdBuilder.addNameHintDecoration(
- groupThreadId,
- groupThreadIDString);
-
- // Since "sv_groupindex" will be translated out to a global var and no
- // longer be considered a system value we can reuse its layout and semantic
- // info
- Index foundRequiredDecorations = 0;
- IRLayoutDecoration* layoutDecoration = nullptr;
- UInt semanticIndex = 0;
- for (auto decoration : var->getDecorations())
- {
- if (auto layoutDecorationTmp = as<IRLayoutDecoration>(decoration))
- {
- layoutDecoration = layoutDecorationTmp;
- foundRequiredDecorations++;
- }
- else if (auto semanticDecoration = as<IRSemanticDecoration>(decoration))
- {
- semanticIndex = semanticDecoration->getSemanticIndex();
- groupThreadIdBuilder.addSemanticDecoration(
- groupThreadId,
- groupThreadIDString,
- (int)semanticIndex);
- foundRequiredDecorations++;
- }
- if (foundRequiredDecorations >= 2)
- break;
- }
- SLANG_ASSERT(layoutDecoration);
- layoutDecoration->removeFromParent();
- layoutDecoration->insertAtStart(groupThreadId);
- SystemValLegalizationWorkItem newWorkItem = {
- groupThreadId,
- groupThreadIDString,
- semanticIndex};
- legalizeSystemValue(entryPoint, newWorkItem);
- }
- }
-
- IRBuilder svBuilder(builder.getModule());
- svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst());
- auto uint3Type = builder.getVectorType(
- builder.getUIntType(),
- builder.getIntValue(builder.getIntType(), 3));
- auto computeExtent =
- emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type);
- if (!computeExtent)
- {
- m_sink->diagnose(
- entryPoint.entryPointFunc,
- Diagnostics::unsupportedSpecializationConstantForNumThreads);
-
- // Fill in placeholder values.
- static const int kAxisCount = 3;
- IRInst* groupExtentAlongAxis[kAxisCount] = {};
- for (int axis = 0; axis < kAxisCount; axis++)
- groupExtentAlongAxis[axis] =
- builder.getIntValue(uint3Type->getElementType(), 1);
- computeExtent =
- builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis);
- }
- auto groupIndexCalc = emitCalcGroupIndex(
- svBuilder,
- entryPointToGroupThreadId[entryPoint.entryPointFunc],
- computeExtent);
- svBuilder.addNameHintDecoration(
- groupIndexCalc,
- UnownedStringSlice("sv_groupindex"));
-
- var->replaceUsesWith(groupIndexCalc);
- var->removeAndDeallocate();
- }
- }
- if (info.isUnsupported)
- {
- reportUnsupportedSystemAttribute(var, semanticName);
- return;
- }
- if (!info.permittedTypes.getCount())
- return;
-
- builder.addTargetSystemValueDecoration(var, info.metalSystemValueName.getUnownedSlice());
-
- bool varTypeIsPermitted = false;
- auto varType = var->getFullType();
- for (auto& permittedType : info.permittedTypes)
- {
- varTypeIsPermitted = varTypeIsPermitted || permittedType == varType;
- }
-
- if (!varTypeIsPermitted)
- {
- // Note: we do not currently prefer any conversion
- // example:
- // * allowed types for semantic: `float4`, `uint4`, `int4`
- // * user used, `float2`
- // * Slang will equally prefer `float4` to `uint4` to `int4`.
- // This means the type may lose data if slang selects `uint4` or `int4`.
- bool foundAConversion = false;
- for (auto permittedType : info.permittedTypes)
- {
- var->setFullType(permittedType);
- builder.setInsertBefore(
- entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
-
- // get uses before we `tryConvertValue` since this creates a new use
- List<IRUse*> uses;
- for (auto use = var->firstUse; use; use = use->nextUse)
- uses.add(use);
-
- auto convertedValue = tryConvertValue(builder, var, varType);
- if (convertedValue == nullptr)
- continue;
-
- foundAConversion = true;
- copyNameHintAndDebugDecorations(convertedValue, var);
-
- for (auto use : uses)
- builder.replaceOperand(use, convertedValue);
- }
- if (!foundAConversion)
- {
- // If we can't convert the value, report an error.
- for (auto permittedType : info.permittedTypes)
- {
- StringBuilder typeNameSB;
- getTypeNameHint(typeNameSB, permittedType);
- m_sink->diagnose(
- var->sourceLoc,
- Diagnostics::systemValueTypeIncompatible,
- semanticName,
- typeNameSB.produceString());
- }
- }
- }
- }
-
- void legalizeSystemValueParameters(EntryPointInfo entryPoint)
- {
- List<SystemValLegalizationWorkItem> systemValWorkItems =
- collectSystemValFromEntryPoint(entryPoint);
-
- for (auto index = 0; index < systemValWorkItems.getCount(); index++)
- {
- legalizeSystemValue(entryPoint, systemValWorkItems[index]);
- }
- fixUpFuncType(entryPoint.entryPointFunc);
- }
-
- void legalizeEntryPointForMetal(EntryPointInfo entryPoint)
- {
- // Input Parameter Legalize
- depointerizeInputParams(entryPoint.entryPointFunc);
- hoistEntryPointParameterFromStruct(entryPoint);
- packStageInParameters(entryPoint);
- flattenInputParameters(entryPoint);
-
- // System Value Legalize
- legalizeSystemValueParameters(entryPoint);
-
- // Output Value Legalize
- wrapReturnValueInStruct(entryPoint);
-
- // Other Legalize
- switch (entryPoint.entryPointDecor->getProfile().getStage())
- {
- case Stage::Amplification:
- legalizeDispatchMeshPayloadForMetal(entryPoint);
- break;
- case Stage::Mesh:
- legalizeMeshEntryPoint(entryPoint);
- break;
- default:
- break;
- }
- }
-};
-
// metal textures only support writing 4-component values, even if the texture is only 1, 2, or
// 3-component in this case the other channels get ignored, but the signature still doesnt match so
// now we have to replace the value being written with a 4-component vector where the new components
@@ -2187,10 +233,7 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
}
}
- LegalizeMetalEntryPointContext context(sink, module);
- for (auto entryPoint : entryPoints)
- context.legalizeEntryPointForMetal(entryPoint);
- context.removeSemanticLayoutsFromLegalizedStructs();
+ legalizeEntryPointVaryingParamsForMetal(module, sink, entryPoints);
MetalAddressSpaceAssigner metalAddressSpaceAssigner;
specializeAddressSpace(module, &metalAddressSpaceAssigner);
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index effc06f3e..efa028703 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -4,1537 +4,169 @@
#include "slang-ir-legalize-binary-operator.h"
#include "slang-ir-legalize-global-values.h"
#include "slang-ir-legalize-varying-params.h"
-#include "slang-ir-util.h"
#include "slang-ir.h"
-#include "slang-parameter-binding.h"
-
-#include <set>
namespace Slang
{
-struct EntryPointInfo
-{
- IRFunc* entryPointFunc;
- IREntryPointDecoration* entryPointDecor;
-};
-
-struct LegalizeWGSLEntryPointContext
+static void legalizeCall(IRCall* call)
{
- HashSet<IRStructField*> semanticInfoToRemove;
- UnownedStringSlice userSemanticName = toSlice("user_semantic");
-
- DiagnosticSink* m_sink;
- IRModule* m_module;
-
- LegalizeWGSLEntryPointContext(DiagnosticSink* sink, IRModule* module)
- : m_sink(sink), m_module(module)
- {
- }
-
- void removeSemanticLayoutsFromLegalizedStructs()
- {
- // WGSL does not allow duplicate attributes to appear in the same shader.
- // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]`
- // must be removed.
- for (auto field : semanticInfoToRemove)
- {
- auto key = field->getKey();
- // Some decorations appear twice, destroy all found
- for (;;)
- {
- if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
- {
- semanticDecor->removeAndDeallocate();
- continue;
- }
- else if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
- {
- layoutDecor->removeAndDeallocate();
- continue;
- }
- break;
- }
- }
- }
-
- // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct
- void flattenInputParameters(EntryPointInfo entryPoint)
- {
- // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members).
- /*
- // Assume the following code
- struct NestedFragment
- {
- float2 p3;
- };
- struct Fragment
- {
- float4 p1;
- float3 p2;
- NestedFragment p3_nested;
- };
-
- // Fragment flattens into
- struct Fragment
- {
- float4 p1;
- float3 p2;
- float2 p3;
- };
- */
-
- // This is important since WGSL does not allow semantic's on a struct
- /*
- // Assume the following code
- struct NestedFragment1
- {
- float2 p3;
- };
- struct Fragment1
- {
- float4 p1 : SV_TARGET0;
- float3 p2 : SV_TARGET1;
- NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct
- };
-
- */
-
- // Unlike Metal, WGSL does NOT allow semantics on members of a nested struct.
- /*
- // Assume the following code
- struct NestedFragment
- {
- float2 p3;
- };
- struct Fragment
- {
- float4 p1 : SV_TARGET0;
- NestedFragment p2 : SV_TARGET1;
- NestedFragment p3 : SV_TARGET2;
- };
-
- // Legalized with flattening
- struct Fragment
- {
- float4 p1 : SV_TARGET0;
- float2 p2 : SV_TARGET1;
- float2 p3 : SV_TARGET2;
- };
- */
-
- auto func = entryPoint.entryPointFunc;
- bool modified = false;
- for (auto param : func->getParams())
- {
- auto layout = findVarLayout(param);
- if (!layout)
- continue;
- if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
- continue;
- if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
- continue;
- // If we find a IRParam with a IRStructType member, we need to flatten the entire
- // IRParam
- if (auto structType = as<IRStructType>(param->getDataType()))
- {
- IRBuilder builder(func);
- MapStructToFlatStruct mapOldFieldToNewField;
-
- // Flatten struct if we have nested IRStructType
- auto flattenedStruct = maybeFlattenNestedStructs(
- builder,
- structType,
- mapOldFieldToNewField,
- semanticInfoToRemove);
- // Validate/rearange all semantics which overlap in our flat struct.
- fixFieldSemanticsOfFlatStruct(flattenedStruct);
- ensureStructHasUserSemantic<LayoutResourceKind::VaryingInput>(
- flattenedStruct,
- layout);
- if (flattenedStruct != structType)
- {
- // Replace the 'old IRParam type' with a 'new IRParam type'
- param->setFullType(flattenedStruct);
-
- // Emit a new variable at EntryPoint of 'old IRParam type'
- builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
- auto dstVal = builder.emitVar(structType);
- auto dstLoad = builder.emitLoad(dstVal);
- param->replaceUsesWith(dstLoad);
- builder.setInsertBefore(dstLoad);
- // Copy the 'new IRParam type' to our 'old IRParam type'
- mapOldFieldToNewField
- .emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>(
- builder,
- dstVal,
- param);
-
- modified = true;
- }
- }
- }
- if (modified)
- fixUpFuncType(func);
- }
-
- struct WGSLSystemValueInfo
- {
- String wgslSystemValueName;
- SystemValueSemanticName wgslSystemValueNameEnum;
- ShortList<IRType*> permittedTypes;
- bool isUnsupported = false;
- WGSLSystemValueInfo()
- {
- // most commonly need 2
- permittedTypes.reserveOverflowBuffer(2);
- }
+ // WGSL does not allow forming a pointer to a sub part of a composite value.
+ // For example, if we have
+ // ```
+ // struct S { float x; float y; };
+ // void foo(inout float v) { v = 1.0f; }
+ // void main() { S s; foo(s.x); }
+ // ```
+ // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`.
+ // And trying to form `&s.x` in WGSL is illegal.
+ // To work around this, we will create a local variable to hold the sub part of
+ // the composite value.
+ // And then pass the local variable to the function.
+ // After the call, we will write back the local variable to the sub part of the
+ // composite value.
+ //
+ IRBuilder builder(call);
+ builder.setInsertBefore(call);
+ struct WritebackPair
+ {
+ IRInst* dest;
+ IRInst* value;
};
+ ShortList<WritebackPair> pendingWritebacks;
- WGSLSystemValueInfo getSystemValueInfo(
- String inSemanticName,
- String* optionalSemanticIndex,
- IRInst* parentVar)
+ for (UInt i = 0; i < call->getArgCount(); i++)
{
- IRBuilder builder(m_module);
- WGSLSystemValueInfo result = {};
- UnownedStringSlice semanticName;
- UnownedStringSlice semanticIndex;
-
- auto hasExplicitIndex =
- splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
- if (!hasExplicitIndex && optionalSemanticIndex)
- semanticIndex = optionalSemanticIndex->getUnownedSlice();
-
- result.wgslSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName);
-
- switch (result.wgslSystemValueNameEnum)
+ auto arg = call->getArg(i);
+ auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
+ if (!ptrType)
+ continue;
+ switch (arg->getOp())
{
-
- case SystemValueSemanticName::CullDistance:
- {
- result.isUnsupported = true;
- }
- break;
-
- case SystemValueSemanticName::ClipDistance:
- {
- // TODO: Implement this based on the 'clip-distances' feature in WGSL
- // https: // www.w3.org/TR/webgpu/#dom-gpufeaturename-clip-distances
- result.isUnsupported = true;
- }
- break;
-
- case SystemValueSemanticName::Coverage:
- {
- result.wgslSystemValueName = toSlice("sample_mask");
- result.permittedTypes.add(builder.getUIntType());
- }
- break;
-
- case SystemValueSemanticName::Depth:
- {
- result.wgslSystemValueName = toSlice("frag_depth");
- result.permittedTypes.add(builder.getBasicType(BaseType::Float));
- }
- break;
-
- case SystemValueSemanticName::DepthGreaterEqual:
- case SystemValueSemanticName::DepthLessEqual:
- {
- result.isUnsupported = true;
- }
- break;
-
- case SystemValueSemanticName::DispatchThreadID:
- {
- result.wgslSystemValueName = toSlice("global_invocation_id");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3)));
- }
- break;
-
- case SystemValueSemanticName::DomainLocation:
- {
- result.isUnsupported = true;
- }
- break;
-
- case SystemValueSemanticName::GroupID:
- {
- result.wgslSystemValueName = toSlice("workgroup_id");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3)));
- }
- break;
-
- case SystemValueSemanticName::GroupIndex:
- {
- result.wgslSystemValueName = toSlice("local_invocation_index");
- result.permittedTypes.add(builder.getUIntType());
- }
- break;
-
- case SystemValueSemanticName::GroupThreadID:
- {
- result.wgslSystemValueName = toSlice("local_invocation_id");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3)));
- }
- break;
-
- case SystemValueSemanticName::GSInstanceID:
- {
- // No Geometry shaders in WGSL
- result.isUnsupported = true;
- }
- break;
-
- case SystemValueSemanticName::InnerCoverage:
- {
- result.isUnsupported = true;
- }
- break;
-
- case SystemValueSemanticName::InstanceID:
- {
- result.wgslSystemValueName = toSlice("instance_index");
- result.permittedTypes.add(builder.getUIntType());
- }
- break;
-
- case SystemValueSemanticName::IsFrontFace:
- {
- result.wgslSystemValueName = toSlice("front_facing");
- result.permittedTypes.add(builder.getBoolType());
- }
- break;
-
- case SystemValueSemanticName::OutputControlPointID:
- case SystemValueSemanticName::PointSize:
- {
- result.isUnsupported = true;
- }
- break;
-
- case SystemValueSemanticName::Position:
- {
- result.wgslSystemValueName = toSlice("position");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::Float),
- builder.getIntValue(builder.getIntType(), 4)));
- break;
- }
-
- case SystemValueSemanticName::PrimitiveID:
- case SystemValueSemanticName::RenderTargetArrayIndex:
- {
- result.isUnsupported = true;
- break;
- }
-
- case SystemValueSemanticName::SampleIndex:
- {
- result.wgslSystemValueName = toSlice("sample_index");
- result.permittedTypes.add(builder.getUIntType());
- break;
- }
-
- case SystemValueSemanticName::StencilRef:
- case SystemValueSemanticName::Target:
- case SystemValueSemanticName::TessFactor:
- {
- result.isUnsupported = true;
- break;
- }
-
- case SystemValueSemanticName::VertexID:
- {
- result.wgslSystemValueName = toSlice("vertex_index");
- result.permittedTypes.add(builder.getUIntType());
- break;
- }
-
- case SystemValueSemanticName::ViewID:
- case SystemValueSemanticName::ViewportArrayIndex:
- case SystemValueSemanticName::StartVertexLocation:
- case SystemValueSemanticName::StartInstanceLocation:
- {
- result.isUnsupported = true;
- break;
- }
-
+ case kIROp_Var:
+ case kIROp_Param:
+ case kIROp_GlobalParam:
+ case kIROp_GlobalVar:
+ continue;
default:
- {
- m_sink->diagnose(
- parentVar,
- Diagnostics::unimplementedSystemValueSemantic,
- semanticName);
- return result;
- }
+ break;
}
- return result;
- }
+ // Create a local variable to hold the input argument.
+ auto var = builder.emitVar(ptrType->getValueType(), AddressSpace::Function);
- void reportUnsupportedSystemAttribute(IRInst* param, String semanticName)
- {
- m_sink->diagnose(
- param->sourceLoc,
- Diagnostics::systemValueAttributeNotSupported,
- semanticName);
+ // Store the input argument into the local variable.
+ builder.emitStore(var, builder.emitLoad(arg));
+ builder.replaceOperand(call->getArgs() + i, var);
+ pendingWritebacks.add({arg, var});
}
- template<LayoutResourceKind K>
- void ensureStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout)
+ // Perform writebacks after the call.
+ builder.setInsertAfter(call);
+ for (auto& pair : pendingWritebacks)
{
- // Ensure each field in an output struct type has either a system semantic or a user
- // semantic, so that signature matching can happen correctly.
- auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
- Index index = 0;
- IRBuilder builder(structType);
- for (auto field : structType->getFields())
- {
- auto key = field->getKey();
- if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
- {
- if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
- {
- auto indexAsString = String(UInt(semanticDecor->getSemanticIndex()));
- auto sysValInfo =
- getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field);
- if (sysValInfo.isUnsupported)
- {
- reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName());
- }
- else
- {
- builder.addTargetSystemValueDecoration(
- key,
- sysValInfo.wgslSystemValueName.getUnownedSlice());
- semanticDecor->removeAndDeallocate();
- }
- }
- index++;
- continue;
- }
- typeLayout->getFieldLayout(index);
- auto fieldLayout = typeLayout->getFieldLayout(index);
- if (auto offsetAttr = fieldLayout->findOffsetAttr(K))
- {
- UInt varOffset = 0;
- if (auto varOffsetAttr = varLayout->findOffsetAttr(K))
- varOffset = varOffsetAttr->getOffset();
- varOffset += offsetAttr->getOffset();
- builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset);
- }
- index++;
- }
- }
-
- // Stores a hicharchy of members and children which map 'oldStruct->member' to
- // 'flatStruct->member' Note: this map assumes we map to FlatStruct since it is easier/faster to
- // process
- struct MapStructToFlatStruct
- {
- /*
- We need a hicharchy map to resolve dependencies for mapping
- oldStruct to newStruct efficently. Example:
-
- MyStruct
- |
- / | \
- / | \
- / | \
- M0<A> M1<A> M2<B>
- | | |
- A_0 A_0 B_0
-
- Without storing hicharchy information, there will be no way to tell apart
- `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField
- only has 1 instance of `A::A0`
- */
-
- enum CopyOptions : int
- {
- // Copy a flattened-struct into a struct
- FlatStructIntoStruct = 0,
-
- // Copy a struct into a flattened-struct
- StructIntoFlatStruct = 1,
- };
-
- private:
- // Children of member if applicable.
- Dictionary<IRStructField*, MapStructToFlatStruct> members;
-
- // Field correlating to MapStructToFlatStruct Node.
- IRInst* node;
- IRStructKey* getKey()
- {
- SLANG_ASSERT(as<IRStructField>(node));
- return as<IRStructField>(node)->getKey();
- }
- IRInst* getNode() { return node; }
- IRType* getFieldType()
- {
- SLANG_ASSERT(as<IRStructField>(node));
- return as<IRStructField>(node)->getFieldType();
- }
-
- // Whom node maps to inside target flatStruct
- IRStructField* targetMapping;
-
- auto begin() { return members.begin(); }
- auto end() { return members.end(); }
-
- // Copies members of oldStruct to/from newFlatStruct. Assumes members of val1 maps to
- // members in val2 using `MapStructToFlatStruct`
- template<int copyOptions>
- static void _emitCopy(
- IRBuilder& builder,
- IRInst* val1,
- IRStructType* type1,
- IRInst* val2,
- IRStructType* type2,
- MapStructToFlatStruct& node)
- {
- for (auto& field1Pair : node)
- {
- auto& field1 = field1Pair.second;
-
- // Get member of val1
- IRInst* fieldAddr1 = nullptr;
- if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
- {
- fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey());
- }
- else
- {
- if (as<IRPtrTypeBase>(val1))
- val1 = builder.emitLoad(val1);
- fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey());
- }
-
- // If val1 is a struct, recurse
- if (auto fieldAsStruct1 = as<IRStructType>(field1.getFieldType()))
- {
- _emitCopy<copyOptions>(
- builder,
- fieldAddr1,
- fieldAsStruct1,
- val2,
- type2,
- field1);
- continue;
- }
-
- // Get member of val2 which maps to val1.member
- auto field2 = field1.getMapping();
- SLANG_ASSERT(field2);
- IRInst* fieldAddr2 = nullptr;
- if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
- {
- if (as<IRPtrTypeBase>(val2))
- val2 = builder.emitLoad(val1);
- fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey());
- }
- else
- {
- fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey());
- }
-
- // Copy val2/val1 member into val1/val2 member
- if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
- {
- builder.emitStore(fieldAddr1, fieldAddr2);
- }
- else
- {
- builder.emitStore(fieldAddr2, fieldAddr1);
- }
- }
- }
-
- public:
- void setNode(IRInst* newNode) { node = newNode; }
- // Get 'MapStructToFlatStruct' that is a child of 'parent'.
- // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'.
- MapStructToFlatStruct& getMember(IRStructField* member) { return members[member]; }
- MapStructToFlatStruct& operator[](IRStructField* member) { return getMember(member); }
-
- void setMapping(IRStructField* newTargetMapping) { targetMapping = newTargetMapping; }
- // Get 'MapStructToFlatStruct' that is a child of 'parent'.
- // Return nullptr if no member is mapped to 'parent'
- IRStructField* getMapping() { return targetMapping; }
-
- // Copies srcVal into dstVal using hicharchy map.
- template<int copyOptions>
- void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal)
- {
- auto dstType = dstVal->getDataType();
- if (auto dstPtrType = as<IRPtrTypeBase>(dstType))
- dstType = dstPtrType->getValueType();
- auto dstStructType = as<IRStructType>(dstType);
- SLANG_ASSERT(dstStructType);
-
- auto srcType = srcVal->getDataType();
- if (auto srcPtrType = as<IRPtrTypeBase>(srcType))
- srcType = srcPtrType->getValueType();
- auto srcStructType = as<IRStructType>(srcType);
- SLANG_ASSERT(srcStructType);
-
- if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
- {
- // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a
- // struct
- SLANG_ASSERT(node == dstStructType);
- _emitCopy<copyOptions>(
- builder,
- dstVal,
- dstStructType,
- srcVal,
- srcStructType,
- *this);
- }
- else
- {
- // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct
- SLANG_ASSERT(node == srcStructType);
- _emitCopy<copyOptions>(
- builder,
- srcVal,
- srcStructType,
- dstVal,
- dstStructType,
- *this);
- }
- }
- };
-
- IRStructType* _flattenNestedStructs(
- IRBuilder& builder,
- IRStructType* dst,
- IRStructType* src,
- IRSemanticDecoration* parentSemanticDecoration,
- IRLayoutDecoration* parentLayout,
- MapStructToFlatStruct& mapFieldToField,
- HashSet<IRStructField*>& varsWithSemanticInfo)
- {
- // For all fields ('oldField') of a struct do the following:
- // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration,
- // IRLayoutDecoration), store these if found.
- // * Do not propagate semantic info if the current node has *any* form of semantic
- // information.
- // Update varsWithSemanticInfo.
- // 2. If IRStructType:
- // 2a. Recurse this function with 'decorations that carry semantic info' from parent.
- // 3. If not IRStructType:
- // 3a. Emit 'newField' with 'newKey' equal to 'oldField' and 'oldKey', respectively,
- // where 'oldKey' is the key corresponding to 'oldField'.
- // Add 'decorations which carry semantic info' to 'newField', and move all decorations
- // of 'oldKey' to 'newKey'.
- // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is
- // needed to copy between types.
- for (auto oldField : src->getFields())
- {
- auto& fieldMappingNode = mapFieldToField[oldField];
- fieldMappingNode.setNode(oldField);
-
- // step 1
- bool foundSemanticDecor = false;
- auto oldKey = oldField->getKey();
- IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration;
- if (auto oldSemanticDecoration = oldKey->findDecoration<IRSemanticDecoration>())
- {
- foundSemanticDecor = true;
- fieldSemanticDecoration = oldSemanticDecoration;
- parentLayout = nullptr;
- }
-
- IRLayoutDecoration* fieldLayout = parentLayout;
- if (auto oldLayout = oldKey->findDecoration<IRLayoutDecoration>())
- {
- fieldLayout = oldLayout;
- if (!foundSemanticDecor)
- fieldSemanticDecoration = nullptr;
- }
- if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout)
- varsWithSemanticInfo.add(oldField);
-
- // step 2a
- if (auto structFieldType = as<IRStructType>(oldField->getFieldType()))
- {
- _flattenNestedStructs(
- builder,
- dst,
- structFieldType,
- fieldSemanticDecoration,
- fieldLayout,
- fieldMappingNode,
- varsWithSemanticInfo);
- continue;
- }
-
- // step 3a
- auto newKey = builder.createStructKey();
- oldKey->transferDecorationsTo(newKey);
-
- auto newField = builder.createStructField(dst, newKey, oldField->getFieldType());
- copyNameHintAndDebugDecorations(newField, oldField);
-
- if (fieldSemanticDecoration)
- builder.addSemanticDecoration(
- newKey,
- fieldSemanticDecoration->getSemanticName(),
- fieldSemanticDecoration->getSemanticIndex());
-
- if (fieldLayout)
- {
- IRLayout* oldLayout = fieldLayout->getLayout();
- List<IRInst*> instToCopy;
- // Only copy certain decorations needed for resolving system semantics
- for (UInt i = 0; i < oldLayout->getOperandCount(); i++)
- {
- auto operand = oldLayout->getOperand(i);
- if (as<IRVarOffsetAttr>(operand) || as<IRUserSemanticAttr>(operand) ||
- as<IRSystemValueSemanticAttr>(operand) || as<IRStageAttr>(operand))
- instToCopy.add(operand);
- }
- IRVarLayout* newLayout = builder.getVarLayout(instToCopy);
- builder.addLayoutDecoration(newKey, newLayout);
- }
- // step 3b
- fieldMappingNode.setMapping(newField);
- }
-
- return dst;
- }
-
- // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there
- // was no struct flattening.
- // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct
- // `IRStructFields`s
- IRStructType* maybeFlattenNestedStructs(
- IRBuilder& builder,
- IRStructType* src,
- MapStructToFlatStruct& mapFieldToField,
- HashSet<IRStructField*>& varsWithSemanticInfo)
- {
- // Find all values inside struct that need flattening and legalization.
- bool hasStructTypeMembers = false;
- for (auto field : src->getFields())
- {
- if (as<IRStructType>(field->getFieldType()))
- {
- hasStructTypeMembers = true;
- break;
- }
- }
- if (!hasStructTypeMembers)
- return src;
-
- // We need to:
- // 1. Make new struct 1:1 with old struct but without nestested structs (flatten)
- // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be
- // handled later).
- // 3. Store the mapping from old to new struct fields to allow copying a old-struct to
- // new-struct.
- builder.setInsertAfter(src);
- auto newStruct = builder.createStructType();
- copyNameHintAndDebugDecorations(newStruct, src);
- mapFieldToField.setNode(src);
- return _flattenNestedStructs(
- builder,
- newStruct,
- src,
- nullptr,
- nullptr,
- mapFieldToField,
- varsWithSemanticInfo);
+ builder.emitStore(pair.dest, builder.emitLoad(pair.value));
}
+}
- // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'.
- // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function.
- template<typename CopyLogicFunc>
- void _replaceAllReturnInst(
- IRBuilder& builder,
- IRFunc* targetFunc,
- IRStructType* newType,
- CopyLogicFunc copyLogicFunc)
+static void legalizeFunc(IRFunc* func)
+{
+ // Insert casts to convert integer return types
+ auto funcReturnType = func->getResultType();
+ if (isIntegralType(funcReturnType))
{
- for (auto block : targetFunc->getBlocks())
+ for (auto block : func->getBlocks())
{
if (auto returnInst = as<IRReturn>(block->getTerminator()))
{
- builder.setInsertBefore(returnInst);
- auto returnVal = returnInst->getVal();
- returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal));
- }
- }
- }
-
- UInt _returnNonOverlappingAttributeIndex(std::set<UInt>& usedSemanticIndex)
- {
- // Find first unused semantic index of equal semantic type
- // to fill any gaps in user set semantic bindings
- UInt prev = 0;
- for (auto i : usedSemanticIndex)
- {
- if (i > prev + 1)
- {
- break;
- }
- prev = i;
- }
- usedSemanticIndex.insert(prev + 1);
- return prev + 1;
- }
-
- template<typename T>
- struct AttributeParentPair
- {
- IRLayoutDecoration* layoutDecor;
- T* attr;
- };
-
- IRLayoutDecoration* _replaceAttributeOfLayout(
- IRBuilder& builder,
- IRLayoutDecoration* parentLayoutDecor,
- IRInst* instToReplace,
- IRInst* instToReplaceWith)
- {
- // Replace `instToReplace` with a `instToReplaceWith`
-
- auto layout = parentLayoutDecor->getLayout();
- // Find the exact same decoration `instToReplace` in-case multiple of the same type exist
- List<IRInst*> opList;
- opList.add(instToReplaceWith);
- for (UInt i = 0; i < layout->getOperandCount(); i++)
- {
- if (layout->getOperand(i) != instToReplace)
- opList.add(layout->getOperand(i));
- }
- auto newLayoutDecor = builder.addLayoutDecoration(
- parentLayoutDecor->getParent(),
- builder.getVarLayout(opList));
- parentLayoutDecor->removeAndDeallocate();
- return newLayoutDecor;
- }
-
- IRLayoutDecoration* _simplifyUserSemanticNames(
- IRBuilder& builder,
- IRLayoutDecoration* layoutDecor)
- {
- // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into
- // ("SV_TARGET", 0) using 'IRUserSemanticAttr' This is done to ensure we can check semantic
- // groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()'
- SLANG_ASSERT(layoutDecor);
- auto layout = layoutDecor->getLayout();
- List<IRInst*> layoutOps;
- layoutOps.reserve(3);
- bool changed = false;
- for (auto attr : layout->getAllAttrs())
- {
- if (auto userSemantic = as<IRUserSemanticAttr>(attr))
- {
- UnownedStringSlice outName;
- UnownedStringSlice outIndex;
- bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex);
-
- changed = true;
- auto newDecoration = builder.getUserSemanticAttr(
- userSemanticName,
- hasStringIndex ? stringToInt(outIndex) : 0);
- userSemantic->replaceUsesWith(newDecoration);
- userSemantic->removeAndDeallocate();
- userSemantic = newDecoration;
-
- layoutOps.add(userSemantic);
- continue;
- }
- layoutOps.add(attr);
- }
- if (changed)
- {
- auto parent = layoutDecor->parent;
- layoutDecor->removeAndDeallocate();
- builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps));
- }
- return layoutDecor;
- }
-
- // Find overlapping field semantics and legalize them
- void fixFieldSemanticsOfFlatStruct(IRStructType* structType)
- {
- // Goal is to ensure we do not have overlapping semantics for the user defined semantics:
- // Note that in WGSL, the semantics can be either `builtin` without index or `location` with
- // index.
- /*
- // Assume the following code
- struct Fragment
- {
- float4 p0 : SV_POSITION;
- float2 p1 : TEXCOORD0;
- float2 p2 : TEXCOORD1;
- float3 p3 : COLOR0;
- float3 p4 : COLOR1;
- };
-
- // Translates into
- struct Fragment
- {
- float4 p0 : BUILTIN_POSITION;
- float2 p1 : LOCATION_0;
- float2 p2 : LOCATION_1;
- float3 p3 : LOCATION_2;
- float3 p4 : LOCATION_3;
- };
- */
-
- // For Multi-Render-Target, the semantic index must be translated to `location` with
- // the same index. Assume the following code
- /*
- struct Fragment
- {
- float4 p0 : SV_TARGET1;
- float4 p1 : SV_TARGET0;
- };
-
- // Translates into
- struct Fragment
- {
- float4 p0 : LOCATION_1;
- float4 p1 : LOCATION_0;
- };
- */
-
- IRBuilder builder(this->m_module);
-
- List<IRSemanticDecoration*> overlappingSemanticsDecor;
- Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>>
- usedSemanticIndexSemanticDecor;
-
- List<AttributeParentPair<IRVarOffsetAttr>> overlappingVarOffset;
- Dictionary<UInt, std::set<UInt, std::less<UInt>>> usedSemanticIndexVarOffset;
-
- List<AttributeParentPair<IRUserSemanticAttr>> overlappingUserSemantic;
- Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>>
- usedSemanticIndexUserSemantic;
-
- // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when
- // legalizing we may destroy and remake a `IRLayoutDecoration*`
- Dictionary<IRLayoutDecoration*, IRLayoutDecoration*> oldLayoutDecorToNew;
-
- // Collect all "semantic info carrying decorations". Any collected decoration will
- // fill up their respective 'Dictionary<SEMANTIC_TYPE, OrderedHashSet<UInt>>'
- // to keep track of in-use offsets for a semantic type.
- // Example: IRSemanticDecoration with name of "SV_TARGET1".
- // * This will have SEMANTIC_TYPE of "sv_target".
- // * This will use up index '1'
- //
- // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to
- // a list of 'overlapping semantic info decorations' so we can legalize this
- // 'semantic info decoration' later.
- //
- // NOTE: this is a flat struct, all members are children of the initial
- // IRStructType.
- for (auto field : structType->getFields())
- {
- auto key = field->getKey();
- if (auto semanticDecoration = key->findDecoration<IRSemanticDecoration>())
- {
- auto semanticName = semanticDecoration->getSemanticName();
-
- // sv_target is treated as a user-semantic because it should be emitted with
- // @location like how the user semantics are emitted.
- // For fragment shader, only sv_target will user @location, and for non-fragment
- // shaders, sv_target is not valid.
- bool isUserSemantic =
- (semanticName.startsWithCaseInsensitive(toSlice("sv_target")) ||
- !semanticName.startsWithCaseInsensitive(toSlice("sv_")));
-
- // Ensure names are in a uniform lowercase format so we can bunch together simmilar
- // semantics.
- UnownedStringSlice outName;
- UnownedStringSlice outIndex;
- bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex);
-
- // user semantics gets all same semantic-name.
- auto loweredName = String(outName).toLower();
- auto loweredNameSlice =
- isUserSemantic ? userSemanticName : loweredName.getUnownedSlice();
- auto newDecoration = builder.addSemanticDecoration(
- key,
- loweredNameSlice,
- hasStringIndex ? stringToInt(outIndex) : 0);
- semanticDecoration->replaceUsesWith(newDecoration);
- semanticDecoration->removeAndDeallocate();
- semanticDecoration = newDecoration;
-
- auto& semanticUse =
- usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()];
- if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end())
- overlappingSemanticsDecor.add(semanticDecoration);
- else
- semanticUse.insert(semanticDecoration->getSemanticIndex());
- }
- if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
- {
- // Ensure names are in a uniform lowercase format so we can bunch together simmilar
- // semantics
- layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor);
- oldLayoutDecorToNew[layoutDecor] = layoutDecor;
- auto layout = layoutDecor->getLayout();
- for (auto attr : layout->getAllAttrs())
- {
- if (auto offset = as<IRVarOffsetAttr>(attr))
- {
- auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()];
- if (semanticUse.find(offset->getOffset()) != semanticUse.end())
- overlappingVarOffset.add({layoutDecor, offset});
- else
- semanticUse.insert(offset->getOffset());
- }
- else if (auto userSemantic = as<IRUserSemanticAttr>(attr))
- {
- auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()];
- if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end())
- overlappingUserSemantic.add({layoutDecor, userSemantic});
- else
- semanticUse.insert(userSemantic->getIndex());
- }
- }
- }
- }
-
- // Legalize all overlapping 'semantic info decorations'
- for (auto decor : overlappingSemanticsDecor)
- {
- auto newOffset = _returnNonOverlappingAttributeIndex(
- usedSemanticIndexSemanticDecor[decor->getSemanticName()]);
- builder.addSemanticDecoration(
- decor->getParent(),
- decor->getSemanticName(),
- (int)newOffset);
- decor->removeAndDeallocate();
- }
- for (auto& varOffset : overlappingVarOffset)
- {
- auto newOffset = _returnNonOverlappingAttributeIndex(
- usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]);
- auto newVarOffset = builder.getVarOffsetAttr(
- varOffset.attr->getResourceKind(),
- newOffset,
- varOffset.attr->getSpace());
- oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout(
- builder,
- oldLayoutDecorToNew[varOffset.layoutDecor],
- varOffset.attr,
- newVarOffset);
- }
- for (auto& userSemantic : overlappingUserSemantic)
- {
- auto newOffset = _returnNonOverlappingAttributeIndex(
- usedSemanticIndexUserSemantic[userSemantic.attr->getName()]);
- auto newUserSemantic =
- builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset);
- oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout(
- builder,
- oldLayoutDecorToNew[userSemantic.layoutDecor],
- userSemantic.attr,
- newUserSemantic);
- }
- }
-
- void wrapReturnValueInStruct(EntryPointInfo entryPoint)
- {
- // Wrap return value into a struct if it is not already a struct.
- // For example, given this entry point:
- // ```
- // float4 main() : SV_Target { return float3(1,2,3); }
- // ```
- // We are going to transform it into:
- // ```
- // struct Output {
- // float4 value : SV_Target;
- // };
- // Output main() { return {float3(1,2,3)}; }
-
- auto func = entryPoint.entryPointFunc;
-
- auto returnType = func->getResultType();
- if (as<IRVoidType>(returnType))
- return;
- auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>();
- if (!entryPointLayoutDecor)
- return;
- auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout());
- if (!entryPointLayout)
- return;
- auto resultLayout = entryPointLayout->getResultLayout();
-
- // If return type is already a struct, just make sure every field has a semantic.
- if (auto returnStructType = as<IRStructType>(returnType))
- {
- IRBuilder builder(func);
- MapStructToFlatStruct mapOldFieldToNewField;
- // Flatten result struct type to ensure we do not have nested semantics
- auto flattenedStruct = maybeFlattenNestedStructs(
- builder,
- returnStructType,
- mapOldFieldToNewField,
- semanticInfoToRemove);
- if (returnStructType != flattenedStruct)
- {
- // Replace all return-values with the flattenedStruct we made.
- _replaceAllReturnInst(
- builder,
- func,
- flattenedStruct,
- [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
- {
- auto srcStructType = as<IRStructType>(srcVal->getDataType());
- SLANG_ASSERT(srcStructType);
- auto dstVal = copyBuilder.emitVar(dstType);
- mapOldFieldToNewField.emitCopy<(
- int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>(
- copyBuilder,
- dstVal,
- srcVal);
- return builder.emitLoad(dstVal);
- });
- fixUpFuncType(func, flattenedStruct);
- }
- // Ensure non-overlapping semantics
- fixFieldSemanticsOfFlatStruct(flattenedStruct);
- ensureStructHasUserSemantic<LayoutResourceKind::VaryingOutput>(
- flattenedStruct,
- resultLayout);
- return;
- }
-
- IRBuilder builder(func);
- builder.setInsertBefore(func);
- IRStructType* structType = builder.createStructType();
- auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
- builder.addNameHintDecoration(
- structType,
- (String(stageText) + toSlice("Output")).getUnownedSlice());
- auto key = builder.createStructKey();
- builder.addNameHintDecoration(key, toSlice("output"));
- builder.addLayoutDecoration(key, resultLayout);
- builder.createStructField(structType, key, returnType);
- IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder);
- structTypeLayoutBuilder.addField(key, resultLayout);
- auto typeLayout = structTypeLayoutBuilder.build();
- IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
- auto varLayout = varLayoutBuilder.build();
- ensureStructHasUserSemantic<LayoutResourceKind::VaryingOutput>(structType, varLayout);
-
- _replaceAllReturnInst(
- builder,
- func,
- structType,
- [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
- { return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); });
-
- // Assign an appropriate system value semantic for stage output
- auto stage = entryPoint.entryPointDecor->getProfile().getStage();
- switch (stage)
- {
- case Stage::Compute:
- case Stage::Fragment:
- {
- IRInst* operands[] = {
- builder.getStringValue(userSemanticName),
- builder.getIntValue(builder.getIntType(), 0)};
- builder.addDecoration(
- key,
- kIROp_SemanticDecoration,
- operands,
- SLANG_COUNT_OF(operands));
- break;
- }
- case Stage::Vertex:
- {
- builder.addTargetSystemValueDecoration(key, toSlice("position"));
- break;
- }
- default:
- SLANG_ASSERT(false);
- return;
- }
-
- fixUpFuncType(func, structType);
- }
-
- IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
- {
- auto fromType = val->getFullType();
- if (auto fromVector = as<IRVectorType>(fromType))
- {
- if (auto toVector = as<IRVectorType>(toType))
- {
- if (fromVector->getElementCount() != toVector->getElementCount())
- {
- fromType = builder.getVectorType(
- fromVector->getElementType(),
- toVector->getElementCount());
- val = builder.emitVectorReshape(fromType, val);
- }
- }
- else if (as<IRBasicType>(toType))
- {
- UInt index = 0;
- val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
- if (toType->getOp() == kIROp_VoidType)
- return nullptr;
- }
- }
- else if (auto fromBasicType = as<IRBasicType>(fromType))
- {
- if (fromBasicType->getOp() == kIROp_VoidType)
- return nullptr;
- if (!as<IRBasicType>(toType))
- return nullptr;
- if (toType->getOp() == kIROp_VoidType)
- return nullptr;
- }
- else
- {
- return nullptr;
- }
- return builder.emitCast(toType, val);
- }
-
- struct SystemValLegalizationWorkItem
- {
- IRInst* var;
- IRType* varType;
- String attrName;
- UInt attrIndex;
- };
-
- std::optional<SystemValLegalizationWorkItem> tryToMakeSystemValWorkItem(
- IRInst* var,
- IRType* varType)
- {
- if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>())
- {
- if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
- {
- return {
- {var,
- varType,
- String(semanticDecoration->getSemanticName()).toLower(),
- (UInt)semanticDecoration->getSemanticIndex()}};
- }
- }
-
- auto layoutDecor = var->findDecoration<IRLayoutDecoration>();
- if (!layoutDecor)
- return {};
- auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>();
- if (!sysValAttr)
- return {};
- auto semanticName = String(sysValAttr->getName());
- auto sysAttrIndex = sysValAttr->getIndex();
-
- return {{var, varType, semanticName, sysAttrIndex}};
- }
-
- List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint)
- {
- List<SystemValLegalizationWorkItem> systemValWorkItems;
- for (auto param : entryPoint.entryPointFunc->getParams())
- {
- if (auto structType = as<IRStructType>(param->getDataType()))
- {
- for (auto field : structType->getFields())
- {
- // Nested struct-s are flattened already by flattenInputParameters().
- SLANG_ASSERT(!as<IRStructType>(field->getFieldType()));
-
- auto key = field->getKey();
- auto fieldType = field->getFieldType();
- auto maybeWorkItem = tryToMakeSystemValWorkItem(key, fieldType);
- if (maybeWorkItem.has_value())
- systemValWorkItems.add(std::move(maybeWorkItem.value()));
- }
- continue;
- }
-
- auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType());
- if (maybeWorkItem.has_value())
- systemValWorkItems.add(std::move(maybeWorkItem.value()));
- }
- return systemValWorkItems;
- }
-
- void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem)
- {
- IRBuilder builder(entryPoint.entryPointFunc);
-
- auto var = workItem.var;
- auto varType = workItem.varType;
- auto semanticName = workItem.attrName;
-
- auto indexAsString = String(workItem.attrIndex);
- auto info = getSystemValueInfo(semanticName, &indexAsString, var);
-
- if (info.isUnsupported)
- {
- reportUnsupportedSystemAttribute(var, semanticName);
- return;
- }
- if (!info.permittedTypes.getCount())
- return;
-
- builder.addTargetSystemValueDecoration(var, info.wgslSystemValueName.getUnownedSlice());
-
- bool varTypeIsPermitted = false;
- for (auto& permittedType : info.permittedTypes)
- {
- varTypeIsPermitted = varTypeIsPermitted || permittedType == varType;
- }
-
- if (!varTypeIsPermitted)
- {
- // Note: we do not currently prefer any conversion
- // example:
- // * allowed types for semantic: `float4`, `uint4`, `int4`
- // * user used, `float2`
- // * Slang will equally prefer `float4` to `uint4` to `int4`.
- // This means the type may lose data if slang selects `uint4` or `int4`.
- bool foundAConversion = false;
- for (auto permittedType : info.permittedTypes)
- {
- var->setFullType(permittedType);
- builder.setInsertBefore(
- entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
-
- // get uses before we `tryConvertValue` since this creates a new use
- List<IRUse*> uses;
- for (auto use = var->firstUse; use; use = use->nextUse)
- uses.add(use);
-
- auto convertedValue = tryConvertValue(builder, var, varType);
- if (convertedValue == nullptr)
- continue;
-
- foundAConversion = true;
- copyNameHintAndDebugDecorations(convertedValue, var);
-
- for (auto use : uses)
- builder.replaceOperand(use, convertedValue);
- }
- if (!foundAConversion)
- {
- // If we can't convert the value, report an error.
- for (auto permittedType : info.permittedTypes)
+ auto returnedValue = returnInst->getOperand(0);
+ auto returnedValueType = returnedValue->getDataType();
+ if (isIntegralType(returnedValueType))
{
- StringBuilder typeNameSB;
- getTypeNameHint(typeNameSB, permittedType);
- m_sink->diagnose(
- var->sourceLoc,
- Diagnostics::systemValueTypeIncompatible,
- semanticName,
- typeNameSB.produceString());
+ IRBuilder builder(returnInst);
+ builder.setInsertBefore(returnInst);
+ auto newOp = builder.emitCast(funcReturnType, returnedValue);
+ builder.replaceOperand(returnInst->getOperands(), newOp);
}
}
}
}
+}
- void legalizeSystemValueParameters(EntryPointInfo entryPoint)
- {
- List<SystemValLegalizationWorkItem> systemValWorkItems =
- collectSystemValFromEntryPoint(entryPoint);
-
- for (auto index = 0; index < systemValWorkItems.getCount(); index++)
- {
- legalizeSystemValue(entryPoint, systemValWorkItems[index]);
- }
- fixUpFuncType(entryPoint.entryPointFunc);
- }
-
- void legalizeEntryPointForWGSL(EntryPointInfo entryPoint)
- {
- // If the entrypoint is receiving varying inputs as a pointer, turn it into a value.
- depointerizeInputParams(entryPoint.entryPointFunc);
-
- // Input Parameter Legalize
- flattenInputParameters(entryPoint);
-
- // System Value Legalize
- legalizeSystemValueParameters(entryPoint);
-
- // Output Value Legalize
- wrapReturnValueInStruct(entryPoint);
- }
-
- void legalizeCall(IRCall* call)
- {
- // WGSL does not allow forming a pointer to a sub part of a composite value.
- // For example, if we have
- // ```
- // struct S { float x; float y; };
- // void foo(inout float v) { v = 1.0f; }
- // void main() { S s; foo(s.x); }
- // ```
- // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`.
- // And trying to form `&s.x` in WGSL is illegal.
- // To work around this, we will create a local variable to hold the sub part of
- // the composite value.
- // And then pass the local variable to the function.
- // After the call, we will write back the local variable to the sub part of the
- // composite value.
- //
- IRBuilder builder(call);
- builder.setInsertBefore(call);
- struct WritebackPair
- {
- IRInst* dest;
- IRInst* value;
- };
- ShortList<WritebackPair> pendingWritebacks;
-
- for (UInt i = 0; i < call->getArgCount(); i++)
- {
- auto arg = call->getArg(i);
- auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
- if (!ptrType)
- continue;
- switch (arg->getOp())
- {
- case kIROp_Var:
- case kIROp_Param:
- case kIROp_GlobalParam:
- case kIROp_GlobalVar:
- continue;
- default:
- break;
- }
-
- // Create a local variable to hold the input argument.
- auto var = builder.emitVar(ptrType->getValueType(), AddressSpace::Function);
-
- // Store the input argument into the local variable.
- builder.emitStore(var, builder.emitLoad(arg));
- builder.replaceOperand(call->getArgs() + i, var);
- pendingWritebacks.add({arg, var});
- }
-
- // Perform writebacks after the call.
- builder.setInsertAfter(call);
- for (auto& pair : pendingWritebacks)
- {
- builder.emitStore(pair.dest, builder.emitLoad(pair.value));
- }
- }
-
- void legalizeFunc(IRFunc* func)
- {
- // Insert casts to convert integer return types
- auto funcReturnType = func->getResultType();
- if (isIntegralType(funcReturnType))
- {
- for (auto block : func->getBlocks())
- {
- if (auto returnInst = as<IRReturn>(block->getTerminator()))
- {
- auto returnedValue = returnInst->getOperand(0);
- auto returnedValueType = returnedValue->getDataType();
- if (isIntegralType(returnedValueType))
- {
- IRBuilder builder(returnInst);
- builder.setInsertBefore(returnInst);
- auto newOp = builder.emitCast(funcReturnType, returnedValue);
- builder.replaceOperand(returnInst->getOperands(), newOp);
- }
- }
- }
- }
- }
-
- void legalizeSwitch(IRSwitch* switchInst)
- {
- // WGSL Requires all switch statements to contain a default case.
- // If the switch statement does not contain a default case, we will add one.
- if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
- return;
- IRBuilder builder(switchInst);
- auto defaultBlock = builder.createBlock();
- builder.setInsertInto(defaultBlock);
- builder.emitBranch(switchInst->getBreakLabel());
- defaultBlock->insertBefore(switchInst->getBreakLabel());
- List<IRInst*> cases;
- for (UInt i = 0; i < switchInst->getCaseCount(); i++)
- {
- cases.add(switchInst->getCaseValue(i));
- cases.add(switchInst->getCaseLabel(i));
- }
- builder.setInsertBefore(switchInst);
- auto newSwitch = builder.emitSwitch(
- switchInst->getCondition(),
- switchInst->getBreakLabel(),
- defaultBlock,
- (UInt)cases.getCount(),
- cases.getBuffer());
- switchInst->transferDecorationsTo(newSwitch);
- switchInst->removeAndDeallocate();
- }
-
- void processInst(IRInst* inst)
- {
- switch (inst->getOp())
- {
- case kIROp_Call:
- legalizeCall(static_cast<IRCall*>(inst));
- break;
-
- case kIROp_Switch:
- legalizeSwitch(as<IRSwitch>(inst));
- break;
-
- // For all binary operators, make sure both side of the operator have the same type
- // (vector-ness and matrix-ness).
- case kIROp_Add:
- case kIROp_Sub:
- case kIROp_Mul:
- case kIROp_Div:
- case kIROp_FRem:
- case kIROp_IRem:
- case kIROp_And:
- case kIROp_Or:
- case kIROp_BitAnd:
- case kIROp_BitOr:
- case kIROp_BitXor:
- case kIROp_Lsh:
- case kIROp_Rsh:
- case kIROp_Eql:
- case kIROp_Neq:
- case kIROp_Greater:
- case kIROp_Less:
- case kIROp_Geq:
- case kIROp_Leq:
- legalizeBinaryOp(inst);
- break;
+static void legalizeSwitch(IRSwitch* switchInst)
+{
+ // WGSL Requires all switch statements to contain a default case.
+ // If the switch statement does not contain a default case, we will add one.
+ if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
+ return;
+ IRBuilder builder(switchInst);
+ auto defaultBlock = builder.createBlock();
+ builder.setInsertInto(defaultBlock);
+ builder.emitBranch(switchInst->getBreakLabel());
+ defaultBlock->insertBefore(switchInst->getBreakLabel());
+ List<IRInst*> cases;
+ for (UInt i = 0; i < switchInst->getCaseCount(); i++)
+ {
+ cases.add(switchInst->getCaseValue(i));
+ cases.add(switchInst->getCaseLabel(i));
+ }
+ builder.setInsertBefore(switchInst);
+ auto newSwitch = builder.emitSwitch(
+ switchInst->getCondition(),
+ switchInst->getBreakLabel(),
+ defaultBlock,
+ (UInt)cases.getCount(),
+ cases.getBuffer());
+ switchInst->transferDecorationsTo(newSwitch);
+ switchInst->removeAndDeallocate();
+}
- case kIROp_Func:
- legalizeFunc(static_cast<IRFunc*>(inst));
- [[fallthrough]];
- default:
- for (auto child : inst->getModifiableChildren())
- {
- processInst(child);
- }
+static void processInst(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ legalizeCall(static_cast<IRCall*>(inst));
+ break;
+
+ case kIROp_Switch:
+ legalizeSwitch(as<IRSwitch>(inst));
+ break;
+
+ // For all binary operators, make sure both side of the operator have the same type
+ // (vector-ness and matrix-ness).
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_FRem:
+ case kIROp_IRem:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ legalizeBinaryOp(inst);
+ break;
+
+ case kIROp_Func:
+ legalizeFunc(static_cast<IRFunc*>(inst));
+ [[fallthrough]];
+ default:
+ for (auto child : inst->getModifiableChildren())
+ {
+ processInst(child);
}
}
-};
+}
struct GlobalInstInliningContext : public GlobalInstInliningContextGeneric
{
@@ -1583,13 +215,10 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink)
entryPoints.add(info);
}
- LegalizeWGSLEntryPointContext context(sink, module);
- for (auto entryPoint : entryPoints)
- context.legalizeEntryPointForWGSL(entryPoint);
- context.removeSemanticLayoutsFromLegalizedStructs();
+ legalizeEntryPointVaryingParamsForWGSL(module, sink, entryPoints);
// Go through every instruction in the module and legalize them as needed.
- context.processInst(module->getModuleInst());
+ processInst(module->getModuleInst());
// Some global insts are illegal, e.g. function calls.
// We need to inline and remove those.