summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2024-11-05 16:31:47 -0800
committerGitHub <noreply@github.com>2024-11-05 16:31:47 -0800
commit79056cd7e0ba261a007e21a98a6f49cb0b032e25 (patch)
treef08c26c9f16ddbfb4a890ce7d201f27d037ccd03 /source
parent4fa76f374c0c35c9c7d186e8addf6861e98baaec (diff)
Legalize the Entry-point for WGSL (#5498)
* Legalize the Entry-point for WGSL The return type of the entry-point needs to be legalized when targeting WGSL. This commit flattens the nested-structs of the return type and the input parameters of the entry-point. Most of code is copied from the legalization code for Metal. The following functions are exactly same to the implementation for Metal or almost same. - flattenInputParameters() : 136 lines - reportUnsupportedSystemAttribute() : 7 lines - ensureResultStructHasUserSemantic() : 46 lines - struct MapStructToFlatStruct : 176 lines - flattenNestedStructs() : 95 lines - maybeFlattenNestedStructs() : 42 lines - _replaceAllReturnInst() : 19 lines - _returnNonOverlappingAttributeIndex() : 16 lines - _replaceAttributeOfLayout() : 23 lines - tryConvertValue() : 41 lines - legalizeSystemValueParameters() : 11 lines They need to be refactored to reduce the duplication later. The test case, `tests/compute/assoctype-lookup.slang`, had a bug that the compute shader was trying to use the varying input/output with the user defined semantics. This commit removes the user defined semantics, because the compute shaders cannot use the user defined semantics. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-c-like.cpp6
-rw-r--r--source/slang/slang-emit-c-like.h2
-rw-r--r--source/slang/slang-emit-wgsl.cpp26
-rw-r--r--source/slang/slang-emit-wgsl.h3
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp1629
5 files changed, 1320 insertions, 346 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 5d04a50db..020c31fdc 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -3293,6 +3293,11 @@ void CLikeSourceEmitter::emitSemanticsUsingVarLayout(IRVarLayout* varLayout)
}
}
+void CLikeSourceEmitter::emitSemanticsPrefix(IRInst* inst)
+{
+ emitSemanticsPrefixImpl(inst);
+}
+
void CLikeSourceEmitter::emitSemantics(IRInst* inst, bool allowOffsetLayout)
{
emitSemanticsImpl(inst, allowOffsetLayout);
@@ -3869,6 +3874,7 @@ void CLikeSourceEmitter::emitStructDeclarationsBlock(
emitPackOffsetModifier(fieldKey, fieldType, packOffsetDecoration);
}
}
+ emitSemanticsPrefix(fieldKey);
emitStructFieldAttributes(structType, ff);
emitMemoryQualifiers(fieldKey);
emitType(fieldType, getName(fieldKey));
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index 9f30c4f41..1da3a64dc 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -355,6 +355,7 @@ public:
void diagnoseUnhandledInst(IRInst* inst);
void emitInst(IRInst* inst);
+ void emitSemanticsPrefix(IRInst* inst);
void emitSemantics(IRInst* inst, bool allowOffsets = false);
void emitSemanticsUsingVarLayout(IRVarLayout* varLayout);
@@ -557,6 +558,7 @@ protected:
SLANG_UNUSED(rate);
SLANG_UNUSED(addressSpace);
}
+ virtual void emitSemanticsPrefixImpl(IRInst* inst) { SLANG_UNUSED(inst); }
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout)
{
SLANG_UNUSED(inst);
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index dea95c6ec..256697bc7 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -236,6 +236,32 @@ static bool isPowerOf2(const uint32_t n)
return (n != 0U) && ((n - 1U) & n) == 0U;
}
+bool WGSLSourceEmitter::maybeEmitSystemSemantic(IRInst* inst)
+{
+ if (auto sysSemanticDecor = inst->findDecoration<IRTargetSystemValueDecoration>())
+ {
+ m_writer->emit("@builtin(");
+ m_writer->emit(sysSemanticDecor->getSemantic());
+ m_writer->emit(")");
+ return true;
+ }
+ return false;
+}
+
+void WGSLSourceEmitter::emitSemanticsPrefixImpl(IRInst* inst)
+{
+ if (!maybeEmitSystemSemantic(inst))
+ {
+ if (auto semanticDecoration = inst->findDecoration<IRSemanticDecoration>())
+ {
+ m_writer->emit("@location(");
+ m_writer->emit(semanticDecoration->getSemanticIndex());
+ m_writer->emit(")");
+ return;
+ }
+ }
+}
+
void WGSLSourceEmitter::emitStructFieldAttributes(IRStructType* structType, IRStructField* field)
{
// Tint emits errors unless we explicitly spell out the layout in some cases, so emit
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
index f178d8f66..6ff9e6786 100644
--- a/source/slang/slang-emit-wgsl.h
+++ b/source/slang/slang-emit-wgsl.h
@@ -38,6 +38,7 @@ public:
virtual void emitParamTypeImpl(IRType* type, const String& name) SLANG_OVERRIDE;
virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE;
virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
+ virtual void emitSemanticsPrefixImpl(IRInst* inst) SLANG_OVERRIDE;
virtual void emitStructFieldAttributes(IRStructType* structType, IRStructField* field)
SLANG_OVERRIDE;
virtual void emitCallArg(IRInst* inst) SLANG_OVERRIDE;
@@ -57,6 +58,8 @@ protected:
void ensurePrelude(const char* preludeText);
private:
+ bool maybeEmitSystemSemantic(IRInst* inst);
+
// Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns
void emitMatrixType(
IRType* const elementType,
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index 8ac58780d..6e554a8f8 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -6,6 +6,8 @@
#include "slang-ir.h"
#include "slang-parameter-binding.h"
+#include <set>
+
namespace Slang
{
@@ -15,455 +17,1388 @@ struct EntryPointInfo
IREntryPointDecoration* entryPointDecor;
};
-struct SystemValLegalizationWorkItem
+struct LegalizeWGSLEntryPointContext
{
- IRInst* var;
- String attrName;
- UInt attrIndex;
-};
+ HashSet<IRStructField*> semanticInfoToRemove;
+ UnownedStringSlice userSemanticName = toSlice("user_semantic");
-struct WGSLSystemValueInfo
-{
- String wgslSystemValueName;
- SystemValueSemanticName wgslSystemValueNameEnum;
- ShortList<IRType*> permittedTypes;
- bool isUnsupported = false;
-};
+ DiagnosticSink* m_sink;
+ IRModule* m_module;
-struct LegalizeWGSLEntryPointContext
-{
LegalizeWGSLEntryPointContext(DiagnosticSink* sink, IRModule* module)
: m_sink(sink), m_module(module)
{
}
- DiagnosticSink* m_sink;
- IRModule* m_module;
+ // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct
+ void flattenInputParameters(EntryPointInfo entryPoint)
+ {
+ // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members).
+ /*
+ // Assume the following code
+ struct NestedFragment
+ {
+ float2 p3;
+ };
+ struct Fragment
+ {
+ float4 p1;
+ float3 p2;
+ NestedFragment p3_nested;
+ };
+
+ // Fragment flattens into
+ struct Fragment
+ {
+ float4 p1;
+ float3 p2;
+ float2 p3;
+ };
+ */
+
+ // This is important since WGSL does not allow semantic's on a struct
+ /*
+ // Assume the following code
+ struct NestedFragment1
+ {
+ float2 p3;
+ };
+ struct Fragment1
+ {
+ float4 p1 : SV_TARGET0;
+ float3 p2 : SV_TARGET1;
+ NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct
+ };
+
+ */
+
+ // Unlike Metal, WGSL does NOT allow semantics on members of a nested struct.
+ /*
+ // Assume the following code
+ struct NestedFragment
+ {
+ float2 p3;
+ };
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ NestedFragment p2 : SV_TARGET1;
+ NestedFragment p3 : SV_TARGET2;
+ };
+
+ // Legalized with flattening
+ struct Fragment
+ {
+ float4 p1 : SV_TARGET0;
+ float2 p2 : SV_TARGET1;
+ float2 p3 : SV_TARGET2;
+ };
+ */
+
+ auto func = entryPoint.entryPointFunc;
+ bool modified = false;
+ for (auto param : func->getParams())
+ {
+ auto layout = findVarLayout(param);
+ if (!layout)
+ continue;
+ if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
+ continue;
+ if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ continue;
+ // If we find a IRParam with a IRStructType member, we need to flatten the entire
+ // IRParam
+ if (auto structType = as<IRStructType>(param->getDataType()))
+ {
+ IRBuilder builder(func);
+ MapStructToFlatStruct mapOldFieldToNewField;
+
+ // Flatten struct if we have nested IRStructType
+ auto flattenedStruct = maybeFlattenNestedStructs(
+ builder,
+ structType,
+ mapOldFieldToNewField,
+ semanticInfoToRemove);
+ if (flattenedStruct != structType)
+ {
+ // Validate/rearange all semantics which overlap in our flat struct
+ fixFieldSemanticsOfFlatStruct(flattenedStruct);
+
+ // Replace the 'old IRParam type' with a 'new IRParam type'
+ param->setFullType(flattenedStruct);
+
+ // Emit a new variable at EntryPoint of 'old IRParam type'
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto dstVal = builder.emitVar(structType);
+ auto dstLoad = builder.emitLoad(dstVal);
+ param->replaceUsesWith(dstLoad);
+ builder.setInsertBefore(dstLoad);
+ // Copy the 'new IRParam type' to our 'old IRParam type'
+ mapOldFieldToNewField
+ .emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>(
+ builder,
+ dstVal,
+ param);
+
+ modified = true;
+ }
+ }
+ }
+ if (modified)
+ fixUpFuncType(func);
+ }
+
+ struct WGSLSystemValueInfo
+ {
+ String wgslSystemValueName;
+ SystemValueSemanticName wgslSystemValueNameEnum;
+ ShortList<IRType*> permittedTypes;
+ bool isUnsupported = false;
+ WGSLSystemValueInfo()
+ {
+ // most commonly need 2
+ permittedTypes.reserveOverflowBuffer(2);
+ }
+ };
- std::optional<SystemValLegalizationWorkItem> makeSystemValWorkItem(IRInst* var);
- void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem);
- List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint);
- void legalizeSystemValueParameters(EntryPointInfo entryPoint);
- void legalizeEntryPointForWGSL(EntryPointInfo entryPoint);
- IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType);
WGSLSystemValueInfo getSystemValueInfo(
String inSemanticName,
String* optionalSemanticIndex,
- IRInst* parentVar);
- void legalizeCall(IRCall* call);
- void legalizeSwitch(IRSwitch* switchInst);
- void legalizeBinaryOp(IRInst* inst);
- void processInst(IRInst* inst);
-};
-
-IRInst* LegalizeWGSLEntryPointContext::tryConvertValue(
- IRBuilder& builder,
- IRInst* val,
- IRType* toType)
-{
- auto fromType = val->getFullType();
- if (auto fromVector = as<IRVectorType>(fromType))
+ IRInst* parentVar)
{
- if (auto toVector = as<IRVectorType>(toType))
+ IRBuilder builder(m_module);
+ WGSLSystemValueInfo result = {};
+ UnownedStringSlice semanticName;
+ UnownedStringSlice semanticIndex;
+
+ auto hasExplicitIndex =
+ splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
+ if (!hasExplicitIndex && optionalSemanticIndex)
+ semanticIndex = optionalSemanticIndex->getUnownedSlice();
+
+ result.wgslSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName);
+
+ switch (result.wgslSystemValueNameEnum)
{
- if (fromVector->getElementCount() != toVector->getElementCount())
+ case SystemValueSemanticName::Position:
{
- fromType = builder.getVectorType(
- fromVector->getElementType(),
- toVector->getElementCount());
- val = builder.emitVectorReshape(fromType, val);
+ result.wgslSystemValueName = toSlice("position");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::Float),
+ builder.getIntValue(builder.getIntType(), 4)));
+ break;
+ }
+
+ case SystemValueSemanticName::DispatchThreadID:
+ {
+ result.wgslSystemValueName = toSlice("global_invocation_id");
+ IRType* const vec3uType{builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3))};
+ result.permittedTypes.add(vec3uType);
+ }
+ break;
+
+ case SystemValueSemanticName::GroupID:
+ {
+ result.wgslSystemValueName = toSlice("workgroup_id");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)));
+ }
+ break;
+
+ case SystemValueSemanticName::GroupThreadID:
+ {
+ result.wgslSystemValueName = toSlice("local_invocation_id");
+ result.permittedTypes.add(builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)));
+ }
+ break;
+
+ case SystemValueSemanticName::GSInstanceID:
+ {
+ // No Geometry shaders in WGSL
+ result.isUnsupported = true;
+ }
+ break;
+
+ case SystemValueSemanticName::GroupIndex:
+ {
+ result.wgslSystemValueName = toSlice("local_invocation_index");
+ result.permittedTypes.add(builder.getUIntType());
+ }
+ break;
+
+ default:
+ {
+ m_sink->diagnose(
+ parentVar,
+ Diagnostics::unimplementedSystemValueSemantic,
+ semanticName);
+ return result;
}
}
- else if (as<IRBasicType>(toType))
- {
- UInt index = 0;
- val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
- if (toType->getOp() == kIROp_VoidType)
- return nullptr;
- }
+
+ return result;
}
- else if (auto fromBasicType = as<IRBasicType>(fromType))
+
+ void reportUnsupportedSystemAttribute(IRInst* param, String semanticName)
{
- if (fromBasicType->getOp() == kIROp_VoidType)
- return nullptr;
- if (!as<IRBasicType>(toType))
- return nullptr;
- if (toType->getOp() == kIROp_VoidType)
- return nullptr;
+ m_sink->diagnose(
+ param->sourceLoc,
+ Diagnostics::systemValueAttributeNotSupported,
+ semanticName);
}
- else
+
+ void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout)
{
- return nullptr;
+ // Ensure each field in an output struct type has either a system semantic or a user
+ // semantic, so that signature matching can happen correctly.
+ auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout());
+ Index index = 0;
+ IRBuilder builder(structType);
+ for (auto field : structType->getFields())
+ {
+ auto key = field->getKey();
+ if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>())
+ {
+ if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ {
+ auto indexAsString = String(UInt(semanticDecor->getSemanticIndex()));
+ auto sysValInfo =
+ getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field);
+ if (sysValInfo.isUnsupported)
+ {
+ reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName());
+ }
+ else
+ {
+ builder.addTargetSystemValueDecoration(
+ key,
+ sysValInfo.wgslSystemValueName.getUnownedSlice());
+ semanticDecor->removeAndDeallocate();
+ }
+ }
+ index++;
+ continue;
+ }
+ typeLayout->getFieldLayout(index);
+ auto fieldLayout = typeLayout->getFieldLayout(index);
+ if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
+ {
+ UInt varOffset = 0;
+ if (auto varOffsetAttr =
+ varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput))
+ varOffset = varOffsetAttr->getOffset();
+ varOffset += offsetAttr->getOffset();
+ builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset);
+ }
+ index++;
+ }
}
- return builder.emitCast(toType, val);
-}
+ // Stores a hicharchy of members and children which map 'oldStruct->member' to
+ // 'flatStruct->member' Note: this map assumes we map to FlatStruct since it is easier/faster to
+ // process
+ struct MapStructToFlatStruct
+ {
+ /*
+ We need a hicharchy map to resolve dependencies for mapping
+ oldStruct to newStruct efficently. Example:
-WGSLSystemValueInfo LegalizeWGSLEntryPointContext::getSystemValueInfo(
- String inSemanticName,
- String* optionalSemanticIndex,
- IRInst* parentVar)
-{
- IRBuilder builder(m_module);
- WGSLSystemValueInfo result = {};
- UnownedStringSlice semanticName;
- UnownedStringSlice semanticIndex;
+ MyStruct
+ |
+ / | \
+ / | \
+ / | \
+ M0<A> M1<A> M2<B>
+ | | |
+ A_0 A_0 B_0
- auto hasExplicitIndex =
- splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex);
- if (!hasExplicitIndex && optionalSemanticIndex)
- semanticIndex = optionalSemanticIndex->getUnownedSlice();
+ Without storing hicharchy information, there will be no way to tell apart
+ `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField
+ only has 1 instance of `A::A0`
+ */
- result.wgslSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName);
+ enum CopyOptions : int
+ {
+ // Copy a flattened-struct into a struct
+ FlatStructIntoStruct = 0,
- switch (result.wgslSystemValueNameEnum)
- {
+ // Copy a struct into a flattened-struct
+ StructIntoFlatStruct = 1,
+ };
+
+ private:
+ // Children of member if applicable.
+ Dictionary<IRStructField*, MapStructToFlatStruct> members;
- case SystemValueSemanticName::DispatchThreadID:
+ // Field correlating to MapStructToFlatStruct Node.
+ IRInst* node;
+ IRStructKey* getKey()
{
- result.wgslSystemValueName = toSlice("global_invocation_id");
- IRType* const vec3uType{builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3))};
- result.permittedTypes.add(vec3uType);
+ SLANG_ASSERT(as<IRStructField>(node));
+ return as<IRStructField>(node)->getKey();
}
- break;
-
- case SystemValueSemanticName::GroupID:
+ IRInst* getNode() { return node; }
+ IRType* getFieldType()
{
- result.wgslSystemValueName = toSlice("workgroup_id");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3)));
+ SLANG_ASSERT(as<IRStructField>(node));
+ return as<IRStructField>(node)->getFieldType();
}
- break;
- case SystemValueSemanticName::GroupThreadID:
+ // Whom node maps to inside target flatStruct
+ IRStructField* targetMapping;
+
+ auto begin() { return members.begin(); }
+ auto end() { return members.end(); }
+
+ // Copies members of oldStruct to/from newFlatStruct. Assumes members of val1 maps to
+ // members in val2 using `MapStructToFlatStruct`
+ template<int copyOptions>
+ static void _emitCopy(
+ IRBuilder& builder,
+ IRInst* val1,
+ IRStructType* type1,
+ IRInst* val2,
+ IRStructType* type2,
+ MapStructToFlatStruct& node)
{
- result.wgslSystemValueName = toSlice("local_invocation_id");
- result.permittedTypes.add(builder.getVectorType(
- builder.getBasicType(BaseType::UInt),
- builder.getIntValue(builder.getIntType(), 3)));
+ for (auto& field1Pair : node)
+ {
+ auto& field1 = field1Pair.second;
+
+ // Get member of val1
+ IRInst* fieldAddr1 = nullptr;
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey());
+ }
+ else
+ {
+ if (as<IRPtrTypeBase>(val1))
+ val1 = builder.emitLoad(val1);
+ fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey());
+ }
+
+ // If val1 is a struct, recurse
+ if (auto fieldAsStruct1 = as<IRStructType>(field1.getFieldType()))
+ {
+ _emitCopy<copyOptions>(
+ builder,
+ fieldAddr1,
+ fieldAsStruct1,
+ val2,
+ type2,
+ field1);
+ continue;
+ }
+
+ // Get member of val2 which maps to val1.member
+ auto field2 = field1.getMapping();
+ SLANG_ASSERT(field2);
+ IRInst* fieldAddr2 = nullptr;
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ if (as<IRPtrTypeBase>(val2))
+ val2 = builder.emitLoad(val1);
+ fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey());
+ }
+ else
+ {
+ fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey());
+ }
+
+ // Copy val2/val1 member into val1/val2 member
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ builder.emitStore(fieldAddr1, fieldAddr2);
+ }
+ else
+ {
+ builder.emitStore(fieldAddr2, fieldAddr1);
+ }
+ }
}
- break;
- case SystemValueSemanticName::GSInstanceID:
+ public:
+ void setNode(IRInst* newNode) { node = newNode; }
+ // Get 'MapStructToFlatStruct' that is a child of 'parent'.
+ // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'.
+ MapStructToFlatStruct& getMember(IRStructField* member) { return members[member]; }
+ MapStructToFlatStruct& operator[](IRStructField* member) { return getMember(member); }
+
+ void setMapping(IRStructField* newTargetMapping) { targetMapping = newTargetMapping; }
+ // Get 'MapStructToFlatStruct' that is a child of 'parent'.
+ // Return nullptr if no member is mapped to 'parent'
+ IRStructField* getMapping() { return targetMapping; }
+
+ // Copies srcVal into dstVal using hicharchy map.
+ template<int copyOptions>
+ void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal)
{
- // No Geometry shaders in WGSL
- result.isUnsupported = true;
+ auto dstType = dstVal->getDataType();
+ if (auto dstPtrType = as<IRPtrTypeBase>(dstType))
+ dstType = dstPtrType->getValueType();
+ auto dstStructType = as<IRStructType>(dstType);
+ SLANG_ASSERT(dstStructType);
+
+ auto srcType = srcVal->getDataType();
+ if (auto srcPtrType = as<IRPtrTypeBase>(srcType))
+ srcType = srcPtrType->getValueType();
+ auto srcStructType = as<IRStructType>(srcType);
+ SLANG_ASSERT(srcStructType);
+
+ if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct)
+ {
+ // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a
+ // struct
+ SLANG_ASSERT(node == dstStructType);
+ _emitCopy<copyOptions>(
+ builder,
+ dstVal,
+ dstStructType,
+ srcVal,
+ srcStructType,
+ *this);
+ }
+ else
+ {
+ // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct
+ SLANG_ASSERT(node == srcStructType);
+ _emitCopy<copyOptions>(
+ builder,
+ srcVal,
+ srcStructType,
+ dstVal,
+ dstStructType,
+ *this);
+ }
}
- break;
+ };
- case SystemValueSemanticName::GroupIndex:
+ IRStructType* _flattenNestedStructs(
+ IRBuilder& builder,
+ IRStructType* dst,
+ IRStructType* src,
+ IRSemanticDecoration* parentSemanticDecoration,
+ IRLayoutDecoration* parentLayout,
+ MapStructToFlatStruct& mapFieldToField,
+ HashSet<IRStructField*>& varsWithSemanticInfo)
+ {
+ // For all fields ('oldField') of a struct do the following:
+ // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration,
+ // IRLayoutDecoration), store these if found.
+ // * Do not propagate semantic info if the current node has *any* form of semantic
+ // information.
+ // Update varsWithSemanticInfo.
+ // 2. If IRStructType:
+ // 2a. Recurse this function with 'decorations that carry semantic info' from parent.
+ // 3. If not IRStructType:
+ // 3a. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic info'.
+ // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is
+ // needed to copy between types.
+ for (auto oldField : src->getFields())
{
- result.wgslSystemValueName = toSlice("local_invocation_index");
- result.permittedTypes.add(builder.getUIntType());
+ auto& fieldMappingNode = mapFieldToField[oldField];
+ fieldMappingNode.setNode(oldField);
+
+ // step 1
+ bool foundSemanticDecor = false;
+ auto oldKey = oldField->getKey();
+ IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration;
+ if (auto oldSemanticDecoration = oldKey->findDecoration<IRSemanticDecoration>())
+ {
+ foundSemanticDecor = true;
+ fieldSemanticDecoration = oldSemanticDecoration;
+ parentLayout = nullptr;
+ }
+
+ IRLayoutDecoration* fieldLayout = parentLayout;
+ if (auto oldLayout = oldKey->findDecoration<IRLayoutDecoration>())
+ {
+ fieldLayout = oldLayout;
+ if (!foundSemanticDecor)
+ fieldSemanticDecoration = nullptr;
+ }
+ if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout)
+ varsWithSemanticInfo.add(oldField);
+
+ // step 2a
+ if (auto structFieldType = as<IRStructType>(oldField->getFieldType()))
+ {
+ _flattenNestedStructs(
+ builder,
+ dst,
+ structFieldType,
+ fieldSemanticDecoration,
+ fieldLayout,
+ fieldMappingNode,
+ varsWithSemanticInfo);
+ continue;
+ }
+
+ // step 3a
+ auto newKey = builder.createStructKey();
+ copyNameHintAndDebugDecorations(newKey, oldKey);
+
+ auto newField = builder.createStructField(dst, newKey, oldField->getFieldType());
+ copyNameHintAndDebugDecorations(newField, oldField);
+
+ if (fieldSemanticDecoration)
+ builder.addSemanticDecoration(
+ newKey,
+ fieldSemanticDecoration->getSemanticName(),
+ fieldSemanticDecoration->getSemanticIndex());
+
+ if (fieldLayout)
+ {
+ IRLayout* oldLayout = fieldLayout->getLayout();
+ List<IRInst*> instToCopy;
+ // Only copy certain decorations needed for resolving system semantics
+ for (UInt i = 0; i < oldLayout->getOperandCount(); i++)
+ {
+ auto operand = oldLayout->getOperand(i);
+ if (as<IRVarOffsetAttr>(operand) || as<IRUserSemanticAttr>(operand) ||
+ as<IRSystemValueSemanticAttr>(operand) || as<IRStageAttr>(operand))
+ instToCopy.add(operand);
+ }
+ IRVarLayout* newLayout = builder.getVarLayout(instToCopy);
+ builder.addLayoutDecoration(newKey, newLayout);
+ }
+ // step 3b
+ fieldMappingNode.setMapping(newField);
}
- break;
- default:
+ return dst;
+ }
+
+ // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there
+ // was no struct flattening.
+ // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct
+ // `IRStructFields`s
+ IRStructType* maybeFlattenNestedStructs(
+ IRBuilder& builder,
+ IRStructType* src,
+ MapStructToFlatStruct& mapFieldToField,
+ HashSet<IRStructField*>& varsWithSemanticInfo)
+ {
+ // Find all values inside struct that need flattening and legalization.
+ bool hasStructTypeMembers = false;
+ for (auto field : src->getFields())
{
- m_sink->diagnose(
- parentVar,
- Diagnostics::unimplementedSystemValueSemantic,
- semanticName);
- return result;
+ if (as<IRStructType>(field->getFieldType()))
+ {
+ hasStructTypeMembers = true;
+ break;
+ }
}
+ if (!hasStructTypeMembers)
+ return src;
+
+ // We need to:
+ // 1. Make new struct 1:1 with old struct but without nestested structs (flatten)
+ // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be
+ // handled later).
+ // 3. Store the mapping from old to new struct fields to allow copying a old-struct to
+ // new-struct.
+ builder.setInsertAfter(src);
+ auto newStruct = builder.createStructType();
+ copyNameHintAndDebugDecorations(newStruct, src);
+ mapFieldToField.setNode(src);
+ return _flattenNestedStructs(
+ builder,
+ newStruct,
+ src,
+ nullptr,
+ nullptr,
+ mapFieldToField,
+ varsWithSemanticInfo);
}
- return result;
-}
+ // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'.
+ // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function.
+ template<typename CopyLogicFunc>
+ void _replaceAllReturnInst(
+ IRBuilder& builder,
+ IRFunc* targetFunc,
+ IRStructType* newType,
+ CopyLogicFunc copyLogicFunc)
+ {
+ for (auto block : targetFunc->getBlocks())
+ {
+ if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ {
+ builder.setInsertBefore(returnInst);
+ auto returnVal = returnInst->getVal();
+ returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal));
+ }
+ }
+ }
-std::optional<SystemValLegalizationWorkItem> LegalizeWGSLEntryPointContext::makeSystemValWorkItem(
- IRInst* var)
-{
- if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>())
+ UInt _returnNonOverlappingAttributeIndex(std::set<UInt>& usedSemanticIndex)
{
- bool svPrefix =
- semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"));
- if (svPrefix)
+ // Find first unused semantic index of equal semantic type
+ // to fill any gaps in user set semantic bindings
+ UInt prev = 0;
+ for (auto i : usedSemanticIndex)
{
- return {
- {var,
- String(semanticDecoration->getSemanticName()).toLower(),
- (UInt)semanticDecoration->getSemanticIndex()}};
+ if (i > prev + 1)
+ {
+ break;
+ }
+ prev = i;
}
+ usedSemanticIndex.insert(prev + 1);
+ return prev + 1;
}
- auto layoutDecor = var->findDecoration<IRLayoutDecoration>();
- if (!layoutDecor)
- return {};
- auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>();
- if (!sysValAttr)
- return {};
- auto semanticName = String(sysValAttr->getName());
- auto sysAttrIndex = sysValAttr->getIndex();
+ template<typename T>
+ struct AttributeParentPair
+ {
+ IRLayoutDecoration* layoutDecor;
+ T* attr;
+ };
- return {{var, semanticName, sysAttrIndex}};
-}
+ IRLayoutDecoration* _replaceAttributeOfLayout(
+ IRBuilder& builder,
+ IRLayoutDecoration* parentLayoutDecor,
+ IRInst* instToReplace,
+ IRInst* instToReplaceWith)
+ {
+ // Replace `instToReplace` with a `instToReplaceWith`
-List<SystemValLegalizationWorkItem> LegalizeWGSLEntryPointContext::collectSystemValFromEntryPoint(
- EntryPointInfo entryPoint)
-{
- List<SystemValLegalizationWorkItem> systemValWorkItems;
- for (auto param : entryPoint.entryPointFunc->getParams())
+ auto layout = parentLayoutDecor->getLayout();
+ // Find the exact same decoration `instToReplace` in-case multiple of the same type exist
+ List<IRInst*> opList;
+ opList.add(instToReplaceWith);
+ for (UInt i = 0; i < layout->getOperandCount(); i++)
+ {
+ if (layout->getOperand(i) != instToReplace)
+ opList.add(layout->getOperand(i));
+ }
+ auto newLayoutDecor = builder.addLayoutDecoration(
+ parentLayoutDecor->getParent(),
+ builder.getVarLayout(opList));
+ parentLayoutDecor->removeAndDeallocate();
+ return newLayoutDecor;
+ }
+
+ IRLayoutDecoration* _simplifyUserSemanticNames(
+ IRBuilder& builder,
+ IRLayoutDecoration* layoutDecor)
{
- auto maybeWorkItem = makeSystemValWorkItem(param);
- if (maybeWorkItem.has_value())
- systemValWorkItems.add(std::move(maybeWorkItem.value()));
+ // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into
+ // ("SV_TARGET", 0) using 'IRUserSemanticAttr' This is done to ensure we can check semantic
+ // groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()'
+ SLANG_ASSERT(layoutDecor);
+ auto layout = layoutDecor->getLayout();
+ List<IRInst*> layoutOps;
+ layoutOps.reserve(3);
+ bool changed = false;
+ for (auto attr : layout->getAllAttrs())
+ {
+ if (auto userSemantic = as<IRUserSemanticAttr>(attr))
+ {
+ UnownedStringSlice outName;
+ UnownedStringSlice outIndex;
+ bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex);
+
+ changed = true;
+ auto newDecoration = builder.getUserSemanticAttr(
+ userSemanticName,
+ hasStringIndex ? stringToInt(outIndex) : 0);
+ userSemantic->replaceUsesWith(newDecoration);
+ userSemantic->removeAndDeallocate();
+ userSemantic = newDecoration;
+
+ layoutOps.add(userSemantic);
+ continue;
+ }
+ layoutOps.add(attr);
+ }
+ if (changed)
+ {
+ auto parent = layoutDecor->parent;
+ layoutDecor->removeAndDeallocate();
+ builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps));
+ }
+ return layoutDecor;
}
- return systemValWorkItems;
-}
-void LegalizeWGSLEntryPointContext::legalizeSystemValue(
- EntryPointInfo entryPoint,
- SystemValLegalizationWorkItem& workItem)
-{
- IRBuilder builder(entryPoint.entryPointFunc);
+ // Find overlapping field semantics and legalize them
+ void fixFieldSemanticsOfFlatStruct(IRStructType* structType)
+ {
+ // Goal is to ensure we do not have overlapping semantics for the user defined semantics:
+ // Note that in WGSL, the semantics can be either `builtin` without index or `location` with
+ // index.
+ /*
+ // Assume the following code
+ struct Fragment
+ {
+ float4 p0 : SV_POSITION;
+ float2 p1 : TEXCOORD0;
+ float2 p2 : TEXCOORD1;
+ float3 p3 : COLOR0;
+ float3 p4 : COLOR1;
+ };
- auto var = workItem.var;
- auto semanticName = workItem.attrName;
+ // Translates into
+ struct Fragment
+ {
+ float4 p0 : BUILTIN_POSITION;
+ float2 p1 : LOCATION_0;
+ float2 p2 : LOCATION_1;
+ float3 p3 : LOCATION_2;
+ float3 p4 : LOCATION_3;
+ };
+ */
- auto indexAsString = String(workItem.attrIndex);
- auto info = getSystemValueInfo(semanticName, &indexAsString, var);
+ // For Multi-Render-Target, the semantic index must be translated to `location` with
+ // the same index. Assume the following code
+ /*
+ struct Fragment
+ {
+ float4 p0 : SV_TARGET1;
+ float4 p1 : SV_TARGET0;
+ };
- if (!info.permittedTypes.getCount())
- return;
+ // Translates into
+ struct Fragment
+ {
+ float4 p0 : LOCATION_1;
+ float4 p1 : LOCATION_0;
+ };
+ */
- builder.addTargetSystemValueDecoration(var, info.wgslSystemValueName.getUnownedSlice());
+ IRBuilder builder(this->m_module);
- bool varTypeIsPermitted = false;
- auto varType = var->getFullType();
- for (auto& permittedType : info.permittedTypes)
- {
- varTypeIsPermitted = varTypeIsPermitted || permittedType == varType;
+ List<IRSemanticDecoration*> overlappingSemanticsDecor;
+ Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>>
+ usedSemanticIndexSemanticDecor;
+
+ List<AttributeParentPair<IRVarOffsetAttr>> overlappingVarOffset;
+ Dictionary<UInt, std::set<UInt, std::less<UInt>>> usedSemanticIndexVarOffset;
+
+ List<AttributeParentPair<IRUserSemanticAttr>> overlappingUserSemantic;
+ Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>>
+ usedSemanticIndexUserSemantic;
+
+ // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when
+ // legalizing we may destroy and remake a `IRLayoutDecoration*`
+ Dictionary<IRLayoutDecoration*, IRLayoutDecoration*> oldLayoutDecorToNew;
+
+ // Collect all "semantic info carrying decorations". Any collected decoration will
+ // fill up their respective 'Dictionary<SEMANTIC_TYPE, OrderedHashSet<UInt>>'
+ // to keep track of in-use offsets for a semantic type.
+ // Example: IRSemanticDecoration with name of "SV_TARGET1".
+ // * This will have SEMANTIC_TYPE of "sv_target".
+ // * This will use up index '1'
+ //
+ // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to
+ // a list of 'overlapping semantic info decorations' so we can legalize this
+ // 'semantic info decoration' later.
+ //
+ // NOTE: this is a flat struct, all members are children of the initial
+ // IRStructType.
+ for (auto field : structType->getFields())
+ {
+ auto key = field->getKey();
+ if (auto semanticDecoration = key->findDecoration<IRSemanticDecoration>())
+ {
+ auto semanticName = semanticDecoration->getSemanticName();
+
+ // sv_target is treated as a user-semantic because it should be emitted with
+ // @location like how the user semantics are emitted.
+ // For fragment shader, only sv_target will user @location, and for non-fragment
+ // shaders, sv_target is not valid.
+ bool isUserSemantic =
+ (semanticName.startsWithCaseInsensitive(toSlice("sv_target")) ||
+ !semanticName.startsWithCaseInsensitive(toSlice("sv_")));
+
+ // Ensure names are in a uniform lowercase format so we can bunch together simmilar
+ // semantics.
+ UnownedStringSlice outName;
+ UnownedStringSlice outIndex;
+ bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex);
+
+ // user semantics gets all same semantic-name.
+ auto loweredName = String(outName).toLower();
+ auto loweredNameSlice =
+ isUserSemantic ? userSemanticName : loweredName.getUnownedSlice();
+ auto newDecoration = builder.addSemanticDecoration(
+ key,
+ loweredNameSlice,
+ hasStringIndex ? stringToInt(outIndex) : 0);
+ semanticDecoration->replaceUsesWith(newDecoration);
+ semanticDecoration->removeAndDeallocate();
+ semanticDecoration = newDecoration;
+
+ auto& semanticUse =
+ usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()];
+ if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end())
+ overlappingSemanticsDecor.add(semanticDecoration);
+ else
+ semanticUse.insert(semanticDecoration->getSemanticIndex());
+ }
+ if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>())
+ {
+ // Ensure names are in a uniform lowercase format so we can bunch together simmilar
+ // semantics
+ layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor);
+ oldLayoutDecorToNew[layoutDecor] = layoutDecor;
+ auto layout = layoutDecor->getLayout();
+ for (auto attr : layout->getAllAttrs())
+ {
+ if (auto offset = as<IRVarOffsetAttr>(attr))
+ {
+ auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()];
+ if (semanticUse.find(offset->getOffset()) != semanticUse.end())
+ overlappingVarOffset.add({layoutDecor, offset});
+ else
+ semanticUse.insert(offset->getOffset());
+ }
+ else if (auto userSemantic = as<IRUserSemanticAttr>(attr))
+ {
+ auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()];
+ if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end())
+ overlappingUserSemantic.add({layoutDecor, userSemantic});
+ else
+ semanticUse.insert(userSemantic->getIndex());
+ }
+ }
+ }
+ }
+
+ // Legalize all overlapping 'semantic info decorations'
+ for (auto decor : overlappingSemanticsDecor)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(
+ usedSemanticIndexSemanticDecor[decor->getSemanticName()]);
+ builder.addSemanticDecoration(
+ decor->getParent(),
+ decor->getSemanticName(),
+ (int)newOffset);
+ decor->removeAndDeallocate();
+ }
+ for (auto& varOffset : overlappingVarOffset)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(
+ usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]);
+ auto newVarOffset = builder.getVarOffsetAttr(
+ varOffset.attr->getResourceKind(),
+ newOffset,
+ varOffset.attr->getSpace());
+ oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout(
+ builder,
+ oldLayoutDecorToNew[varOffset.layoutDecor],
+ varOffset.attr,
+ newVarOffset);
+ }
+ for (auto& userSemantic : overlappingUserSemantic)
+ {
+ auto newOffset = _returnNonOverlappingAttributeIndex(
+ usedSemanticIndexUserSemantic[userSemantic.attr->getName()]);
+ auto newUserSemantic =
+ builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset);
+ oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout(
+ builder,
+ oldLayoutDecorToNew[userSemantic.layoutDecor],
+ userSemantic.attr,
+ newUserSemantic);
+ }
}
- if (!varTypeIsPermitted)
+ void wrapReturnValueInStruct(EntryPointInfo entryPoint)
{
- // Note: we do not currently prefer any conversion
- // example:
- // * allowed types for semantic: `float4`, `uint4`, `int4`
- // * user used, `float2`
- // * Slang will equally prefer `float4` to `uint4` to `int4`.
- // This means the type may lose data if slang selects `uint4` or `int4`.
- bool foundAConversion = false;
- for (auto permittedType : info.permittedTypes)
- {
- var->setFullType(permittedType);
- builder.setInsertBefore(
- entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
-
- // get uses before we `tryConvertValue` since this creates a new use
- List<IRUse*> uses;
- for (auto use = var->firstUse; use; use = use->nextUse)
- uses.add(use);
-
- auto convertedValue = tryConvertValue(builder, var, varType);
- if (convertedValue == nullptr)
- continue;
+ // Wrap return value into a struct if it is not already a struct.
+ // For example, given this entry point:
+ // ```
+ // float4 main() : SV_Target { return float3(1,2,3); }
+ // ```
+ // We are going to transform it into:
+ // ```
+ // struct Output {
+ // float4 value : SV_Target;
+ // };
+ // Output main() { return {float3(1,2,3)}; }
- foundAConversion = true;
- copyNameHintAndDebugDecorations(convertedValue, var);
+ auto func = entryPoint.entryPointFunc;
- for (auto use : uses)
- builder.replaceOperand(use, convertedValue);
+ auto returnType = func->getResultType();
+ if (as<IRVoidType>(returnType))
+ return;
+ auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>();
+ if (!entryPointLayoutDecor)
+ return;
+ auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout());
+ if (!entryPointLayout)
+ return;
+ auto resultLayout = entryPointLayout->getResultLayout();
+
+ // If return type is already a struct, just make sure every field has a semantic.
+ if (auto returnStructType = as<IRStructType>(returnType))
+ {
+ IRBuilder builder(func);
+ MapStructToFlatStruct mapOldFieldToNewField;
+ // Flatten result struct type to ensure we do not have nested semantics
+ auto flattenedStruct = maybeFlattenNestedStructs(
+ builder,
+ returnStructType,
+ mapOldFieldToNewField,
+ semanticInfoToRemove);
+ if (returnStructType != flattenedStruct)
+ {
+ // Replace all return-values with the flattenedStruct we made.
+ _replaceAllReturnInst(
+ builder,
+ func,
+ flattenedStruct,
+ [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
+ {
+ auto srcStructType = as<IRStructType>(srcVal->getDataType());
+ SLANG_ASSERT(srcStructType);
+ auto dstVal = copyBuilder.emitVar(dstType);
+ mapOldFieldToNewField.emitCopy<(
+ int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>(
+ copyBuilder,
+ dstVal,
+ srcVal);
+ return builder.emitLoad(dstVal);
+ });
+ fixUpFuncType(func, flattenedStruct);
+ }
+ // Ensure non-overlapping semantics
+ fixFieldSemanticsOfFlatStruct(flattenedStruct);
+ ensureResultStructHasUserSemantic(flattenedStruct, resultLayout);
+ return;
}
- if (!foundAConversion)
+
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ IRStructType* structType = builder.createStructType();
+ auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage());
+ builder.addNameHintDecoration(
+ structType,
+ (String(stageText) + toSlice("Output")).getUnownedSlice());
+ auto key = builder.createStructKey();
+ builder.addNameHintDecoration(key, toSlice("output"));
+ builder.addLayoutDecoration(key, resultLayout);
+ builder.createStructField(structType, key, returnType);
+ IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder);
+ structTypeLayoutBuilder.addField(key, resultLayout);
+ auto typeLayout = structTypeLayoutBuilder.build();
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+ auto varLayout = varLayoutBuilder.build();
+ ensureResultStructHasUserSemantic(structType, varLayout);
+
+ _replaceAllReturnInst(
+ builder,
+ func,
+ structType,
+ [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst*
+ { return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); });
+
+ // Assign an appropriate system value semantic for stage output
+ auto stage = entryPoint.entryPointDecor->getProfile().getStage();
+ switch (stage)
{
- // If we can't convert the value, report an error.
- for (auto permittedType : info.permittedTypes)
+ case Stage::Compute:
+ case Stage::Fragment:
{
- StringBuilder typeNameSB;
- getTypeNameHint(typeNameSB, permittedType);
- m_sink->diagnose(
- var->sourceLoc,
- Diagnostics::systemValueTypeIncompatible,
- semanticName,
- typeNameSB.produceString());
+ IRInst* operands[] = {
+ builder.getStringValue(userSemanticName),
+ builder.getIntValue(builder.getIntType(), 0)};
+ builder.addDecoration(
+ key,
+ kIROp_SemanticDecoration,
+ operands,
+ SLANG_COUNT_OF(operands));
+ break;
}
+ case Stage::Vertex:
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("position"));
+ break;
+ }
+ default:
+ SLANG_ASSERT(false);
+ return;
}
- }
-}
-void LegalizeWGSLEntryPointContext::legalizeSystemValueParameters(EntryPointInfo entryPoint)
-{
- List<SystemValLegalizationWorkItem> systemValWorkItems =
- collectSystemValFromEntryPoint(entryPoint);
+ fixUpFuncType(func, structType);
+ }
- for (auto index = 0; index < systemValWorkItems.getCount(); index++)
+ IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
{
- legalizeSystemValue(entryPoint, systemValWorkItems[index]);
+ auto fromType = val->getFullType();
+ if (auto fromVector = as<IRVectorType>(fromType))
+ {
+ if (auto toVector = as<IRVectorType>(toType))
+ {
+ if (fromVector->getElementCount() != toVector->getElementCount())
+ {
+ fromType = builder.getVectorType(
+ fromVector->getElementType(),
+ toVector->getElementCount());
+ val = builder.emitVectorReshape(fromType, val);
+ }
+ }
+ else if (as<IRBasicType>(toType))
+ {
+ UInt index = 0;
+ val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
+ if (toType->getOp() == kIROp_VoidType)
+ return nullptr;
+ }
+ }
+ else if (auto fromBasicType = as<IRBasicType>(fromType))
+ {
+ if (fromBasicType->getOp() == kIROp_VoidType)
+ return nullptr;
+ if (!as<IRBasicType>(toType))
+ return nullptr;
+ if (toType->getOp() == kIROp_VoidType)
+ return nullptr;
+ }
+ else
+ {
+ return nullptr;
+ }
+ return builder.emitCast(toType, val);
}
-}
-void LegalizeWGSLEntryPointContext::legalizeEntryPointForWGSL(EntryPointInfo entryPoint)
-{
- legalizeSystemValueParameters(entryPoint);
-}
-
-void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call)
-{
- // WGSL does not allow forming a pointer to a sub part of a composite value.
- // For example, if we have
- // ```
- // struct S { float x; float y; };
- // void foo(inout float v) { v = 1.0f; }
- // void main() { S s; foo(s.x); }
- // ```
- // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`.
- // And trying to form `&s.x` in WGSL is illegal.
- // To work around this, we will create a local variable to hold the sub part of
- // the composite value.
- // And then pass the local variable to the function.
- // After the call, we will write back the local variable to the sub part of the
- // composite value.
- //
- IRBuilder builder(call);
- builder.setInsertBefore(call);
- struct WritebackPair
+ struct SystemValLegalizationWorkItem
{
- IRInst* dest;
- IRInst* value;
+ IRInst* var;
+ IRType* varType;
+ String attrName;
+ UInt attrIndex;
};
- ShortList<WritebackPair> pendingWritebacks;
- for (UInt i = 0; i < call->getArgCount(); i++)
+ std::optional<SystemValLegalizationWorkItem> tryToMakeSystemValWorkItem(
+ IRInst* var,
+ IRType* varType)
{
- auto arg = call->getArg(i);
- auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
- if (!ptrType)
- continue;
- switch (arg->getOp())
+ if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>())
{
- case kIROp_Var:
- case kIROp_Param:
- continue;
- default:
- break;
+ if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ {
+ return {
+ {var,
+ varType,
+ String(semanticDecoration->getSemanticName()).toLower(),
+ (UInt)semanticDecoration->getSemanticIndex()}};
+ }
+ }
+
+ auto layoutDecor = var->findDecoration<IRLayoutDecoration>();
+ if (!layoutDecor)
+ return {};
+ auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>();
+ if (!sysValAttr)
+ return {};
+ auto semanticName = String(sysValAttr->getName());
+ auto sysAttrIndex = sysValAttr->getIndex();
+
+ return {{var, varType, semanticName, sysAttrIndex}};
+ }
+
+ List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint)
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems;
+ for (auto param : entryPoint.entryPointFunc->getParams())
+ {
+ if (auto structType = as<IRStructType>(param->getDataType()))
+ {
+ for (auto field : structType->getFields())
+ {
+ // Nested struct-s are flattened already by flattenInputParameters().
+ SLANG_ASSERT(!as<IRStructType>(field->getFieldType()));
+
+ auto key = field->getKey();
+ auto fieldType = field->getFieldType();
+ auto maybeWorkItem = tryToMakeSystemValWorkItem(key, fieldType);
+ if (maybeWorkItem.has_value())
+ systemValWorkItems.add(std::move(maybeWorkItem.value()));
+ }
+ continue;
+ }
+
+ auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType());
+ if (maybeWorkItem.has_value())
+ systemValWorkItems.add(std::move(maybeWorkItem.value()));
}
+ return systemValWorkItems;
+ }
+
+ void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem)
+ {
+ IRBuilder builder(entryPoint.entryPointFunc);
+
+ auto var = workItem.var;
+ auto varType = workItem.varType;
+ auto semanticName = workItem.attrName;
- // Create a local variable to hold the input argument.
- auto var = builder.emitVar(ptrType->getValueType(), AddressSpace::Function);
+ auto indexAsString = String(workItem.attrIndex);
+ auto info = getSystemValueInfo(semanticName, &indexAsString, var);
- // Store the input argument into the local variable.
- builder.emitStore(var, builder.emitLoad(arg));
- builder.replaceOperand(call->getArgs() + i, var);
- pendingWritebacks.add({arg, var});
+ if (info.isUnsupported)
+ {
+ reportUnsupportedSystemAttribute(var, semanticName);
+ return;
+ }
+ if (!info.permittedTypes.getCount())
+ return;
+
+ builder.addTargetSystemValueDecoration(var, info.wgslSystemValueName.getUnownedSlice());
+
+ bool varTypeIsPermitted = false;
+ for (auto& permittedType : info.permittedTypes)
+ {
+ varTypeIsPermitted = varTypeIsPermitted || permittedType == varType;
+ }
+
+ if (!varTypeIsPermitted)
+ {
+ // Note: we do not currently prefer any conversion
+ // example:
+ // * allowed types for semantic: `float4`, `uint4`, `int4`
+ // * user used, `float2`
+ // * Slang will equally prefer `float4` to `uint4` to `int4`.
+ // This means the type may lose data if slang selects `uint4` or `int4`.
+ bool foundAConversion = false;
+ for (auto permittedType : info.permittedTypes)
+ {
+ var->setFullType(permittedType);
+ builder.setInsertBefore(
+ entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+
+ // get uses before we `tryConvertValue` since this creates a new use
+ List<IRUse*> uses;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ uses.add(use);
+
+ auto convertedValue = tryConvertValue(builder, var, varType);
+ if (convertedValue == nullptr)
+ continue;
+
+ foundAConversion = true;
+ copyNameHintAndDebugDecorations(convertedValue, var);
+
+ for (auto use : uses)
+ builder.replaceOperand(use, convertedValue);
+ }
+ if (!foundAConversion)
+ {
+ // If we can't convert the value, report an error.
+ for (auto permittedType : info.permittedTypes)
+ {
+ StringBuilder typeNameSB;
+ getTypeNameHint(typeNameSB, permittedType);
+ m_sink->diagnose(
+ var->sourceLoc,
+ Diagnostics::systemValueTypeIncompatible,
+ semanticName,
+ typeNameSB.produceString());
+ }
+ }
+ }
}
- // Perform writebacks after the call.
- builder.setInsertAfter(call);
- for (auto& pair : pendingWritebacks)
+ void legalizeSystemValueParameters(EntryPointInfo entryPoint)
{
- builder.emitStore(pair.dest, builder.emitLoad(pair.value));
+ List<SystemValLegalizationWorkItem> systemValWorkItems =
+ collectSystemValFromEntryPoint(entryPoint);
+
+ for (auto index = 0; index < systemValWorkItems.getCount(); index++)
+ {
+ legalizeSystemValue(entryPoint, systemValWorkItems[index]);
+ }
+ fixUpFuncType(entryPoint.entryPointFunc);
}
-}
-void LegalizeWGSLEntryPointContext::legalizeSwitch(IRSwitch* switchInst)
-{
- // WGSL Requires all switch statements to contain a default case.
- // If the switch statement does not contain a default case, we will add one.
- if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
- return;
- IRBuilder builder(switchInst);
- auto defaultBlock = builder.createBlock();
- builder.setInsertInto(defaultBlock);
- builder.emitBranch(switchInst->getBreakLabel());
- defaultBlock->insertBefore(switchInst->getBreakLabel());
- List<IRInst*> cases;
- for (UInt i = 0; i < switchInst->getCaseCount(); i++)
+ void legalizeEntryPointForWGSL(EntryPointInfo entryPoint)
{
- cases.add(switchInst->getCaseValue(i));
- cases.add(switchInst->getCaseLabel(i));
+ // Input Parameter Legalize
+ flattenInputParameters(entryPoint);
+
+ // System Value Legalize
+ legalizeSystemValueParameters(entryPoint);
+
+ // Output Value Legalize
+ wrapReturnValueInStruct(entryPoint);
}
- builder.setInsertBefore(switchInst);
- auto newSwitch = builder.emitSwitch(
- switchInst->getCondition(),
- switchInst->getBreakLabel(),
- defaultBlock,
- (UInt)cases.getCount(),
- cases.getBuffer());
- switchInst->transferDecorationsTo(newSwitch);
- switchInst->removeAndDeallocate();
-}
-void LegalizeWGSLEntryPointContext::legalizeBinaryOp(IRInst* inst)
-{
- auto isVectorOrMatrix = [](IRType* type)
+ void legalizeCall(IRCall* call)
{
- switch (type->getOp())
+ // WGSL does not allow forming a pointer to a sub part of a composite value.
+ // For example, if we have
+ // ```
+ // struct S { float x; float y; };
+ // void foo(inout float v) { v = 1.0f; }
+ // void main() { S s; foo(s.x); }
+ // ```
+ // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`.
+ // And trying to form `&s.x` in WGSL is illegal.
+ // To work around this, we will create a local variable to hold the sub part of
+ // the composite value.
+ // And then pass the local variable to the function.
+ // After the call, we will write back the local variable to the sub part of the
+ // composite value.
+ //
+ IRBuilder builder(call);
+ builder.setInsertBefore(call);
+ struct WritebackPair
{
- case kIROp_VectorType:
- case kIROp_MatrixType:
- return true;
- default:
- return false;
+ IRInst* dest;
+ IRInst* value;
+ };
+ ShortList<WritebackPair> pendingWritebacks;
+
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ auto arg = call->getArg(i);
+ auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
+ if (!ptrType)
+ continue;
+ switch (arg->getOp())
+ {
+ case kIROp_Var:
+ case kIROp_Param:
+ continue;
+ default:
+ break;
+ }
+
+ // Create a local variable to hold the input argument.
+ auto var = builder.emitVar(ptrType->getValueType(), AddressSpace::Function);
+
+ // Store the input argument into the local variable.
+ builder.emitStore(var, builder.emitLoad(arg));
+ builder.replaceOperand(call->getArgs() + i, var);
+ pendingWritebacks.add({arg, var});
}
- };
- if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
- as<IRBasicType>(inst->getOperand(1)->getDataType()))
+
+ // Perform writebacks after the call.
+ builder.setInsertAfter(call);
+ for (auto& pair : pendingWritebacks)
+ {
+ builder.emitStore(pair.dest, builder.emitLoad(pair.value));
+ }
+ }
+
+ void legalizeSwitch(IRSwitch* switchInst)
{
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto newRhs = builder.emitMakeCompositeFromScalar(
- inst->getOperand(0)->getDataType(),
- inst->getOperand(1));
- builder.replaceOperand(inst->getOperands() + 1, newRhs);
+ // WGSL Requires all switch statements to contain a default case.
+ // If the switch statement does not contain a default case, we will add one.
+ if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
+ return;
+ IRBuilder builder(switchInst);
+ auto defaultBlock = builder.createBlock();
+ builder.setInsertInto(defaultBlock);
+ builder.emitBranch(switchInst->getBreakLabel());
+ defaultBlock->insertBefore(switchInst->getBreakLabel());
+ List<IRInst*> cases;
+ for (UInt i = 0; i < switchInst->getCaseCount(); i++)
+ {
+ cases.add(switchInst->getCaseValue(i));
+ cases.add(switchInst->getCaseLabel(i));
+ }
+ builder.setInsertBefore(switchInst);
+ auto newSwitch = builder.emitSwitch(
+ switchInst->getCondition(),
+ switchInst->getBreakLabel(),
+ defaultBlock,
+ (UInt)cases.getCount(),
+ cases.getBuffer());
+ switchInst->transferDecorationsTo(newSwitch);
+ switchInst->removeAndDeallocate();
}
- else if (
- as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
- isVectorOrMatrix(inst->getOperand(1)->getDataType()))
+
+ void legalizeBinaryOp(IRInst* inst)
{
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto newLhs = builder.emitMakeCompositeFromScalar(
- inst->getOperand(1)->getDataType(),
- inst->getOperand(0));
- builder.replaceOperand(inst->getOperands(), newLhs);
+ auto isVectorOrMatrix = [](IRType* type)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ return true;
+ default:
+ return false;
+ }
+ };
+ if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
+ as<IRBasicType>(inst->getOperand(1)->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newRhs = builder.emitMakeCompositeFromScalar(
+ inst->getOperand(0)->getDataType(),
+ inst->getOperand(1));
+ builder.replaceOperand(inst->getOperands() + 1, newRhs);
+ }
+ else if (
+ as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
+ isVectorOrMatrix(inst->getOperand(1)->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newLhs = builder.emitMakeCompositeFromScalar(
+ inst->getOperand(1)->getDataType(),
+ inst->getOperand(0));
+ builder.replaceOperand(inst->getOperands(), newLhs);
+ }
}
-}
-void LegalizeWGSLEntryPointContext::processInst(IRInst* inst)
-{
- switch (inst->getOp())
+ void processInst(IRInst* inst)
{
- case kIROp_Call:
- legalizeCall(static_cast<IRCall*>(inst));
- break;
- case kIROp_Switch:
- legalizeSwitch(as<IRSwitch>(inst));
- break;
-
- // For all binary operators, make sure both side of the operator have the same type
- // (vector-ness and matrix-ness).
- case kIROp_Add:
- case kIROp_Sub:
- case kIROp_Mul:
- case kIROp_Div:
- case kIROp_FRem:
- case kIROp_IRem:
- case kIROp_And:
- case kIROp_Or:
- case kIROp_BitAnd:
- case kIROp_BitOr:
- case kIROp_BitXor:
- case kIROp_Lsh:
- case kIROp_Rsh:
- case kIROp_Eql:
- case kIROp_Neq:
- case kIROp_Greater:
- case kIROp_Less:
- case kIROp_Geq:
- case kIROp_Leq:
- legalizeBinaryOp(inst);
- break;
-
- default:
- for (auto child : inst->getModifiableChildren())
- processInst(child);
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ legalizeCall(static_cast<IRCall*>(inst));
+ break;
+
+ case kIROp_Switch:
+ legalizeSwitch(as<IRSwitch>(inst));
+ break;
+
+ // For all binary operators, make sure both side of the operator have the same type
+ // (vector-ness and matrix-ness).
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_FRem:
+ case kIROp_IRem:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ legalizeBinaryOp(inst);
+ break;
+
+ default:
+ for (auto child : inst->getModifiableChildren())
+ {
+ processInst(child);
+ }
+ }
}
-}
+};
+
void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink)
{
List<EntryPointInfo> entryPoints;
@@ -484,7 +1419,9 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink)
LegalizeWGSLEntryPointContext context(sink, module);
for (auto entryPoint : entryPoints)
+ {
context.legalizeEntryPointForWGSL(entryPoint);
+ }
// Go through every instruction in the module and legalize them as needed.
context.processInst(module->getModuleInst());