diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2024-11-05 16:31:47 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-05 16:31:47 -0800 |
| commit | 79056cd7e0ba261a007e21a98a6f49cb0b032e25 (patch) | |
| tree | f08c26c9f16ddbfb4a890ce7d201f27d037ccd03 | |
| parent | 4fa76f374c0c35c9c7d186e8addf6861e98baaec (diff) | |
Legalize the Entry-point for WGSL (#5498)
* Legalize the Entry-point for WGSL
The return type of the entry-point needs to be legalized when targeting
WGSL.
This commit flattens the nested-structs of the return type and the input
parameters of the entry-point.
Most of code is copied from the legalization code for Metal. The
following functions are exactly same to the implementation for Metal or
almost same.
- flattenInputParameters() : 136 lines
- reportUnsupportedSystemAttribute() : 7 lines
- ensureResultStructHasUserSemantic() : 46 lines
- struct MapStructToFlatStruct : 176 lines
- flattenNestedStructs() : 95 lines
- maybeFlattenNestedStructs() : 42 lines
- _replaceAllReturnInst() : 19 lines
- _returnNonOverlappingAttributeIndex() : 16 lines
- _replaceAttributeOfLayout() : 23 lines
- tryConvertValue() : 41 lines
- legalizeSystemValueParameters() : 11 lines
They need to be refactored to reduce the duplication later.
The test case, `tests/compute/assoctype-lookup.slang`, had a bug that
the compute shader was trying to use the varying input/output with the
user defined semantics.
This commit removes the user defined semantics, because the compute
shaders cannot use the user defined semantics.
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 26 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-wgsl-legalize.cpp | 1629 | ||||
| -rw-r--r-- | tests/compute/assoctype-lookup.slang | 6 | ||||
| -rw-r--r-- | tests/wgsl/nested-varying-input.slang | 53 |
7 files changed, 1376 insertions, 349 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 5d04a50db..020c31fdc 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3293,6 +3293,11 @@ void CLikeSourceEmitter::emitSemanticsUsingVarLayout(IRVarLayout* varLayout) } } +void CLikeSourceEmitter::emitSemanticsPrefix(IRInst* inst) +{ + emitSemanticsPrefixImpl(inst); +} + void CLikeSourceEmitter::emitSemantics(IRInst* inst, bool allowOffsetLayout) { emitSemanticsImpl(inst, allowOffsetLayout); @@ -3869,6 +3874,7 @@ void CLikeSourceEmitter::emitStructDeclarationsBlock( emitPackOffsetModifier(fieldKey, fieldType, packOffsetDecoration); } } + emitSemanticsPrefix(fieldKey); emitStructFieldAttributes(structType, ff); emitMemoryQualifiers(fieldKey); emitType(fieldType, getName(fieldKey)); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 9f30c4f41..1da3a64dc 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -355,6 +355,7 @@ public: void diagnoseUnhandledInst(IRInst* inst); void emitInst(IRInst* inst); + void emitSemanticsPrefix(IRInst* inst); void emitSemantics(IRInst* inst, bool allowOffsets = false); void emitSemanticsUsingVarLayout(IRVarLayout* varLayout); @@ -557,6 +558,7 @@ protected: SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); } + virtual void emitSemanticsPrefixImpl(IRInst* inst) { SLANG_UNUSED(inst); } virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) { SLANG_UNUSED(inst); diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index dea95c6ec..256697bc7 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -236,6 +236,32 @@ static bool isPowerOf2(const uint32_t n) return (n != 0U) && ((n - 1U) & n) == 0U; } +bool WGSLSourceEmitter::maybeEmitSystemSemantic(IRInst* inst) +{ + if (auto sysSemanticDecor = inst->findDecoration<IRTargetSystemValueDecoration>()) + { + m_writer->emit("@builtin("); + m_writer->emit(sysSemanticDecor->getSemantic()); + m_writer->emit(")"); + return true; + } + return false; +} + +void WGSLSourceEmitter::emitSemanticsPrefixImpl(IRInst* inst) +{ + if (!maybeEmitSystemSemantic(inst)) + { + if (auto semanticDecoration = inst->findDecoration<IRSemanticDecoration>()) + { + m_writer->emit("@location("); + m_writer->emit(semanticDecoration->getSemanticIndex()); + m_writer->emit(")"); + return; + } + } +} + void WGSLSourceEmitter::emitStructFieldAttributes(IRStructType* structType, IRStructField* field) { // Tint emits errors unless we explicitly spell out the layout in some cases, so emit diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h index f178d8f66..6ff9e6786 100644 --- a/source/slang/slang-emit-wgsl.h +++ b/source/slang/slang-emit-wgsl.h @@ -38,6 +38,7 @@ public: virtual void emitParamTypeImpl(IRType* type, const String& name) SLANG_OVERRIDE; virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE; virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; + virtual void emitSemanticsPrefixImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitStructFieldAttributes(IRStructType* structType, IRStructField* field) SLANG_OVERRIDE; virtual void emitCallArg(IRInst* inst) SLANG_OVERRIDE; @@ -57,6 +58,8 @@ protected: void ensurePrelude(const char* preludeText); private: + bool maybeEmitSystemSemantic(IRInst* inst); + // Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns void emitMatrixType( IRType* const elementType, diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 8ac58780d..6e554a8f8 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -6,6 +6,8 @@ #include "slang-ir.h" #include "slang-parameter-binding.h" +#include <set> + namespace Slang { @@ -15,455 +17,1388 @@ struct EntryPointInfo IREntryPointDecoration* entryPointDecor; }; -struct SystemValLegalizationWorkItem +struct LegalizeWGSLEntryPointContext { - IRInst* var; - String attrName; - UInt attrIndex; -}; + HashSet<IRStructField*> semanticInfoToRemove; + UnownedStringSlice userSemanticName = toSlice("user_semantic"); -struct WGSLSystemValueInfo -{ - String wgslSystemValueName; - SystemValueSemanticName wgslSystemValueNameEnum; - ShortList<IRType*> permittedTypes; - bool isUnsupported = false; -}; + DiagnosticSink* m_sink; + IRModule* m_module; -struct LegalizeWGSLEntryPointContext -{ LegalizeWGSLEntryPointContext(DiagnosticSink* sink, IRModule* module) : m_sink(sink), m_module(module) { } - DiagnosticSink* m_sink; - IRModule* m_module; + // Flattens all struct parameters of an entryPoint to ensure parameters are a flat struct + void flattenInputParameters(EntryPointInfo entryPoint) + { + // Goal is to ensure we have a flattened IRParam (0 nested IRStructType members). + /* + // Assume the following code + struct NestedFragment + { + float2 p3; + }; + struct Fragment + { + float4 p1; + float3 p2; + NestedFragment p3_nested; + }; + + // Fragment flattens into + struct Fragment + { + float4 p1; + float3 p2; + float2 p3; + }; + */ + + // This is important since WGSL does not allow semantic's on a struct + /* + // Assume the following code + struct NestedFragment1 + { + float2 p3; + }; + struct Fragment1 + { + float4 p1 : SV_TARGET0; + float3 p2 : SV_TARGET1; + NestedFragment p3_nested : SV_TARGET2; // error, semantic on struct + }; + + */ + + // Unlike Metal, WGSL does NOT allow semantics on members of a nested struct. + /* + // Assume the following code + struct NestedFragment + { + float2 p3; + }; + struct Fragment + { + float4 p1 : SV_TARGET0; + NestedFragment p2 : SV_TARGET1; + NestedFragment p3 : SV_TARGET2; + }; + + // Legalized with flattening + struct Fragment + { + float4 p1 : SV_TARGET0; + float2 p2 : SV_TARGET1; + float2 p3 : SV_TARGET2; + }; + */ + + auto func = entryPoint.entryPointFunc; + bool modified = false; + for (auto param : func->getParams()) + { + auto layout = findVarLayout(param); + if (!layout) + continue; + if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) + continue; + if (param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + continue; + // If we find a IRParam with a IRStructType member, we need to flatten the entire + // IRParam + if (auto structType = as<IRStructType>(param->getDataType())) + { + IRBuilder builder(func); + MapStructToFlatStruct mapOldFieldToNewField; + + // Flatten struct if we have nested IRStructType + auto flattenedStruct = maybeFlattenNestedStructs( + builder, + structType, + mapOldFieldToNewField, + semanticInfoToRemove); + if (flattenedStruct != structType) + { + // Validate/rearange all semantics which overlap in our flat struct + fixFieldSemanticsOfFlatStruct(flattenedStruct); + + // Replace the 'old IRParam type' with a 'new IRParam type' + param->setFullType(flattenedStruct); + + // Emit a new variable at EntryPoint of 'old IRParam type' + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto dstVal = builder.emitVar(structType); + auto dstLoad = builder.emitLoad(dstVal); + param->replaceUsesWith(dstLoad); + builder.setInsertBefore(dstLoad); + // Copy the 'new IRParam type' to our 'old IRParam type' + mapOldFieldToNewField + .emitCopy<(int)MapStructToFlatStruct::CopyOptions::FlatStructIntoStruct>( + builder, + dstVal, + param); + + modified = true; + } + } + } + if (modified) + fixUpFuncType(func); + } + + struct WGSLSystemValueInfo + { + String wgslSystemValueName; + SystemValueSemanticName wgslSystemValueNameEnum; + ShortList<IRType*> permittedTypes; + bool isUnsupported = false; + WGSLSystemValueInfo() + { + // most commonly need 2 + permittedTypes.reserveOverflowBuffer(2); + } + }; - std::optional<SystemValLegalizationWorkItem> makeSystemValWorkItem(IRInst* var); - void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem); - List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint); - void legalizeSystemValueParameters(EntryPointInfo entryPoint); - void legalizeEntryPointForWGSL(EntryPointInfo entryPoint); - IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType); WGSLSystemValueInfo getSystemValueInfo( String inSemanticName, String* optionalSemanticIndex, - IRInst* parentVar); - void legalizeCall(IRCall* call); - void legalizeSwitch(IRSwitch* switchInst); - void legalizeBinaryOp(IRInst* inst); - void processInst(IRInst* inst); -}; - -IRInst* LegalizeWGSLEntryPointContext::tryConvertValue( - IRBuilder& builder, - IRInst* val, - IRType* toType) -{ - auto fromType = val->getFullType(); - if (auto fromVector = as<IRVectorType>(fromType)) + IRInst* parentVar) { - if (auto toVector = as<IRVectorType>(toType)) + IRBuilder builder(m_module); + WGSLSystemValueInfo result = {}; + UnownedStringSlice semanticName; + UnownedStringSlice semanticIndex; + + auto hasExplicitIndex = + splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex); + if (!hasExplicitIndex && optionalSemanticIndex) + semanticIndex = optionalSemanticIndex->getUnownedSlice(); + + result.wgslSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName); + + switch (result.wgslSystemValueNameEnum) { - if (fromVector->getElementCount() != toVector->getElementCount()) + case SystemValueSemanticName::Position: { - fromType = builder.getVectorType( - fromVector->getElementType(), - toVector->getElementCount()); - val = builder.emitVectorReshape(fromType, val); + result.wgslSystemValueName = toSlice("position"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::Float), + builder.getIntValue(builder.getIntType(), 4))); + break; + } + + case SystemValueSemanticName::DispatchThreadID: + { + result.wgslSystemValueName = toSlice("global_invocation_id"); + IRType* const vec3uType{builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3))}; + result.permittedTypes.add(vec3uType); + } + break; + + case SystemValueSemanticName::GroupID: + { + result.wgslSystemValueName = toSlice("workgroup_id"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3))); + } + break; + + case SystemValueSemanticName::GroupThreadID: + { + result.wgslSystemValueName = toSlice("local_invocation_id"); + result.permittedTypes.add(builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3))); + } + break; + + case SystemValueSemanticName::GSInstanceID: + { + // No Geometry shaders in WGSL + result.isUnsupported = true; + } + break; + + case SystemValueSemanticName::GroupIndex: + { + result.wgslSystemValueName = toSlice("local_invocation_index"); + result.permittedTypes.add(builder.getUIntType()); + } + break; + + default: + { + m_sink->diagnose( + parentVar, + Diagnostics::unimplementedSystemValueSemantic, + semanticName); + return result; } } - else if (as<IRBasicType>(toType)) - { - UInt index = 0; - val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); - if (toType->getOp() == kIROp_VoidType) - return nullptr; - } + + return result; } - else if (auto fromBasicType = as<IRBasicType>(fromType)) + + void reportUnsupportedSystemAttribute(IRInst* param, String semanticName) { - if (fromBasicType->getOp() == kIROp_VoidType) - return nullptr; - if (!as<IRBasicType>(toType)) - return nullptr; - if (toType->getOp() == kIROp_VoidType) - return nullptr; + m_sink->diagnose( + param->sourceLoc, + Diagnostics::systemValueAttributeNotSupported, + semanticName); } - else + + void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout) { - return nullptr; + // Ensure each field in an output struct type has either a system semantic or a user + // semantic, so that signature matching can happen correctly. + auto typeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout()); + Index index = 0; + IRBuilder builder(structType); + for (auto field : structType->getFields()) + { + auto key = field->getKey(); + if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>()) + { + if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + { + auto indexAsString = String(UInt(semanticDecor->getSemanticIndex())); + auto sysValInfo = + getSystemValueInfo(semanticDecor->getSemanticName(), &indexAsString, field); + if (sysValInfo.isUnsupported) + { + reportUnsupportedSystemAttribute(field, semanticDecor->getSemanticName()); + } + else + { + builder.addTargetSystemValueDecoration( + key, + sysValInfo.wgslSystemValueName.getUnownedSlice()); + semanticDecor->removeAndDeallocate(); + } + } + index++; + continue; + } + typeLayout->getFieldLayout(index); + auto fieldLayout = typeLayout->getFieldLayout(index); + if (auto offsetAttr = fieldLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) + { + UInt varOffset = 0; + if (auto varOffsetAttr = + varLayout->findOffsetAttr(LayoutResourceKind::VaryingOutput)) + varOffset = varOffsetAttr->getOffset(); + varOffset += offsetAttr->getOffset(); + builder.addSemanticDecoration(key, toSlice("_slang_attr"), (int)varOffset); + } + index++; + } } - return builder.emitCast(toType, val); -} + // Stores a hicharchy of members and children which map 'oldStruct->member' to + // 'flatStruct->member' Note: this map assumes we map to FlatStruct since it is easier/faster to + // process + struct MapStructToFlatStruct + { + /* + We need a hicharchy map to resolve dependencies for mapping + oldStruct to newStruct efficently. Example: -WGSLSystemValueInfo LegalizeWGSLEntryPointContext::getSystemValueInfo( - String inSemanticName, - String* optionalSemanticIndex, - IRInst* parentVar) -{ - IRBuilder builder(m_module); - WGSLSystemValueInfo result = {}; - UnownedStringSlice semanticName; - UnownedStringSlice semanticIndex; + MyStruct + | + / | \ + / | \ + / | \ + M0<A> M1<A> M2<B> + | | | + A_0 A_0 B_0 - auto hasExplicitIndex = - splitNameAndIndex(inSemanticName.getUnownedSlice(), semanticName, semanticIndex); - if (!hasExplicitIndex && optionalSemanticIndex) - semanticIndex = optionalSemanticIndex->getUnownedSlice(); + Without storing hicharchy information, there will be no way to tell apart + `myStruct.M0.A0` from `myStruct.M1.A0` since IRStructKey/IRStructField + only has 1 instance of `A::A0` + */ - result.wgslSystemValueNameEnum = convertSystemValueSemanticNameToEnum(semanticName); + enum CopyOptions : int + { + // Copy a flattened-struct into a struct + FlatStructIntoStruct = 0, - switch (result.wgslSystemValueNameEnum) - { + // Copy a struct into a flattened-struct + StructIntoFlatStruct = 1, + }; + + private: + // Children of member if applicable. + Dictionary<IRStructField*, MapStructToFlatStruct> members; - case SystemValueSemanticName::DispatchThreadID: + // Field correlating to MapStructToFlatStruct Node. + IRInst* node; + IRStructKey* getKey() { - result.wgslSystemValueName = toSlice("global_invocation_id"); - IRType* const vec3uType{builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3))}; - result.permittedTypes.add(vec3uType); + SLANG_ASSERT(as<IRStructField>(node)); + return as<IRStructField>(node)->getKey(); } - break; - - case SystemValueSemanticName::GroupID: + IRInst* getNode() { return node; } + IRType* getFieldType() { - result.wgslSystemValueName = toSlice("workgroup_id"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3))); + SLANG_ASSERT(as<IRStructField>(node)); + return as<IRStructField>(node)->getFieldType(); } - break; - case SystemValueSemanticName::GroupThreadID: + // Whom node maps to inside target flatStruct + IRStructField* targetMapping; + + auto begin() { return members.begin(); } + auto end() { return members.end(); } + + // Copies members of oldStruct to/from newFlatStruct. Assumes members of val1 maps to + // members in val2 using `MapStructToFlatStruct` + template<int copyOptions> + static void _emitCopy( + IRBuilder& builder, + IRInst* val1, + IRStructType* type1, + IRInst* val2, + IRStructType* type2, + MapStructToFlatStruct& node) { - result.wgslSystemValueName = toSlice("local_invocation_id"); - result.permittedTypes.add(builder.getVectorType( - builder.getBasicType(BaseType::UInt), - builder.getIntValue(builder.getIntType(), 3))); + for (auto& field1Pair : node) + { + auto& field1 = field1Pair.second; + + // Get member of val1 + IRInst* fieldAddr1 = nullptr; + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + fieldAddr1 = builder.emitFieldAddress(type1, val1, field1.getKey()); + } + else + { + if (as<IRPtrTypeBase>(val1)) + val1 = builder.emitLoad(val1); + fieldAddr1 = builder.emitFieldExtract(type1, val1, field1.getKey()); + } + + // If val1 is a struct, recurse + if (auto fieldAsStruct1 = as<IRStructType>(field1.getFieldType())) + { + _emitCopy<copyOptions>( + builder, + fieldAddr1, + fieldAsStruct1, + val2, + type2, + field1); + continue; + } + + // Get member of val2 which maps to val1.member + auto field2 = field1.getMapping(); + SLANG_ASSERT(field2); + IRInst* fieldAddr2 = nullptr; + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + if (as<IRPtrTypeBase>(val2)) + val2 = builder.emitLoad(val1); + fieldAddr2 = builder.emitFieldExtract(type2, val2, field2->getKey()); + } + else + { + fieldAddr2 = builder.emitFieldAddress(type2, val2, field2->getKey()); + } + + // Copy val2/val1 member into val1/val2 member + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + builder.emitStore(fieldAddr1, fieldAddr2); + } + else + { + builder.emitStore(fieldAddr2, fieldAddr1); + } + } } - break; - case SystemValueSemanticName::GSInstanceID: + public: + void setNode(IRInst* newNode) { node = newNode; } + // Get 'MapStructToFlatStruct' that is a child of 'parent'. + // Make 'MapStructToFlatStruct' if no 'member' is currently mapped to 'parent'. + MapStructToFlatStruct& getMember(IRStructField* member) { return members[member]; } + MapStructToFlatStruct& operator[](IRStructField* member) { return getMember(member); } + + void setMapping(IRStructField* newTargetMapping) { targetMapping = newTargetMapping; } + // Get 'MapStructToFlatStruct' that is a child of 'parent'. + // Return nullptr if no member is mapped to 'parent' + IRStructField* getMapping() { return targetMapping; } + + // Copies srcVal into dstVal using hicharchy map. + template<int copyOptions> + void emitCopy(IRBuilder& builder, IRInst* dstVal, IRInst* srcVal) { - // No Geometry shaders in WGSL - result.isUnsupported = true; + auto dstType = dstVal->getDataType(); + if (auto dstPtrType = as<IRPtrTypeBase>(dstType)) + dstType = dstPtrType->getValueType(); + auto dstStructType = as<IRStructType>(dstType); + SLANG_ASSERT(dstStructType); + + auto srcType = srcVal->getDataType(); + if (auto srcPtrType = as<IRPtrTypeBase>(srcType)) + srcType = srcPtrType->getValueType(); + auto srcStructType = as<IRStructType>(srcType); + SLANG_ASSERT(srcStructType); + + if constexpr (copyOptions == (int)CopyOptions::FlatStructIntoStruct) + { + // CopyOptions::FlatStructIntoStruct copy a flattened-struct (mapped member) into a + // struct + SLANG_ASSERT(node == dstStructType); + _emitCopy<copyOptions>( + builder, + dstVal, + dstStructType, + srcVal, + srcStructType, + *this); + } + else + { + // CopyOptions::StructIntoFlatStruct copy a struct into a flattened-struct + SLANG_ASSERT(node == srcStructType); + _emitCopy<copyOptions>( + builder, + srcVal, + srcStructType, + dstVal, + dstStructType, + *this); + } } - break; + }; - case SystemValueSemanticName::GroupIndex: + IRStructType* _flattenNestedStructs( + IRBuilder& builder, + IRStructType* dst, + IRStructType* src, + IRSemanticDecoration* parentSemanticDecoration, + IRLayoutDecoration* parentLayout, + MapStructToFlatStruct& mapFieldToField, + HashSet<IRStructField*>& varsWithSemanticInfo) + { + // For all fields ('oldField') of a struct do the following: + // 1. Check for 'decorations which carry semantic info' (IRSemanticDecoration, + // IRLayoutDecoration), store these if found. + // * Do not propagate semantic info if the current node has *any* form of semantic + // information. + // Update varsWithSemanticInfo. + // 2. If IRStructType: + // 2a. Recurse this function with 'decorations that carry semantic info' from parent. + // 3. If not IRStructType: + // 3a. Emit 'newField' equal to 'oldField', add 'decorations which carry semantic info'. + // 3b. Store a mapping from 'oldField' to 'newField' in 'mapFieldToField'. This info is + // needed to copy between types. + for (auto oldField : src->getFields()) { - result.wgslSystemValueName = toSlice("local_invocation_index"); - result.permittedTypes.add(builder.getUIntType()); + auto& fieldMappingNode = mapFieldToField[oldField]; + fieldMappingNode.setNode(oldField); + + // step 1 + bool foundSemanticDecor = false; + auto oldKey = oldField->getKey(); + IRSemanticDecoration* fieldSemanticDecoration = parentSemanticDecoration; + if (auto oldSemanticDecoration = oldKey->findDecoration<IRSemanticDecoration>()) + { + foundSemanticDecor = true; + fieldSemanticDecoration = oldSemanticDecoration; + parentLayout = nullptr; + } + + IRLayoutDecoration* fieldLayout = parentLayout; + if (auto oldLayout = oldKey->findDecoration<IRLayoutDecoration>()) + { + fieldLayout = oldLayout; + if (!foundSemanticDecor) + fieldSemanticDecoration = nullptr; + } + if (fieldSemanticDecoration != parentSemanticDecoration || parentLayout != fieldLayout) + varsWithSemanticInfo.add(oldField); + + // step 2a + if (auto structFieldType = as<IRStructType>(oldField->getFieldType())) + { + _flattenNestedStructs( + builder, + dst, + structFieldType, + fieldSemanticDecoration, + fieldLayout, + fieldMappingNode, + varsWithSemanticInfo); + continue; + } + + // step 3a + auto newKey = builder.createStructKey(); + copyNameHintAndDebugDecorations(newKey, oldKey); + + auto newField = builder.createStructField(dst, newKey, oldField->getFieldType()); + copyNameHintAndDebugDecorations(newField, oldField); + + if (fieldSemanticDecoration) + builder.addSemanticDecoration( + newKey, + fieldSemanticDecoration->getSemanticName(), + fieldSemanticDecoration->getSemanticIndex()); + + if (fieldLayout) + { + IRLayout* oldLayout = fieldLayout->getLayout(); + List<IRInst*> instToCopy; + // Only copy certain decorations needed for resolving system semantics + for (UInt i = 0; i < oldLayout->getOperandCount(); i++) + { + auto operand = oldLayout->getOperand(i); + if (as<IRVarOffsetAttr>(operand) || as<IRUserSemanticAttr>(operand) || + as<IRSystemValueSemanticAttr>(operand) || as<IRStageAttr>(operand)) + instToCopy.add(operand); + } + IRVarLayout* newLayout = builder.getVarLayout(instToCopy); + builder.addLayoutDecoration(newKey, newLayout); + } + // step 3b + fieldMappingNode.setMapping(newField); } - break; - default: + return dst; + } + + // Returns a `IRStructType*` without any `IRStructType*` members. `src` may be returned if there + // was no struct flattening. + // @param mapFieldToField Behavior maps all `IRStructField` of `src` to the new struct + // `IRStructFields`s + IRStructType* maybeFlattenNestedStructs( + IRBuilder& builder, + IRStructType* src, + MapStructToFlatStruct& mapFieldToField, + HashSet<IRStructField*>& varsWithSemanticInfo) + { + // Find all values inside struct that need flattening and legalization. + bool hasStructTypeMembers = false; + for (auto field : src->getFields()) { - m_sink->diagnose( - parentVar, - Diagnostics::unimplementedSystemValueSemantic, - semanticName); - return result; + if (as<IRStructType>(field->getFieldType())) + { + hasStructTypeMembers = true; + break; + } } + if (!hasStructTypeMembers) + return src; + + // We need to: + // 1. Make new struct 1:1 with old struct but without nestested structs (flatten) + // 2. Ensure semantic attributes propegate. This will create overlapping semantics (can be + // handled later). + // 3. Store the mapping from old to new struct fields to allow copying a old-struct to + // new-struct. + builder.setInsertAfter(src); + auto newStruct = builder.createStructType(); + copyNameHintAndDebugDecorations(newStruct, src); + mapFieldToField.setNode(src); + return _flattenNestedStructs( + builder, + newStruct, + src, + nullptr, + nullptr, + mapFieldToField, + varsWithSemanticInfo); } - return result; -} + // Replaces all 'IRReturn' by copying the current 'IRReturn' to a new var of type 'newType'. + // Copying logic from 'IRReturn' to 'newType' is controlled by 'copyLogicFunc' function. + template<typename CopyLogicFunc> + void _replaceAllReturnInst( + IRBuilder& builder, + IRFunc* targetFunc, + IRStructType* newType, + CopyLogicFunc copyLogicFunc) + { + for (auto block : targetFunc->getBlocks()) + { + if (auto returnInst = as<IRReturn>(block->getTerminator())) + { + builder.setInsertBefore(returnInst); + auto returnVal = returnInst->getVal(); + returnInst->setOperand(0, copyLogicFunc(builder, newType, returnVal)); + } + } + } -std::optional<SystemValLegalizationWorkItem> LegalizeWGSLEntryPointContext::makeSystemValWorkItem( - IRInst* var) -{ - if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>()) + UInt _returnNonOverlappingAttributeIndex(std::set<UInt>& usedSemanticIndex) { - bool svPrefix = - semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")); - if (svPrefix) + // Find first unused semantic index of equal semantic type + // to fill any gaps in user set semantic bindings + UInt prev = 0; + for (auto i : usedSemanticIndex) { - return { - {var, - String(semanticDecoration->getSemanticName()).toLower(), - (UInt)semanticDecoration->getSemanticIndex()}}; + if (i > prev + 1) + { + break; + } + prev = i; } + usedSemanticIndex.insert(prev + 1); + return prev + 1; } - auto layoutDecor = var->findDecoration<IRLayoutDecoration>(); - if (!layoutDecor) - return {}; - auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>(); - if (!sysValAttr) - return {}; - auto semanticName = String(sysValAttr->getName()); - auto sysAttrIndex = sysValAttr->getIndex(); + template<typename T> + struct AttributeParentPair + { + IRLayoutDecoration* layoutDecor; + T* attr; + }; - return {{var, semanticName, sysAttrIndex}}; -} + IRLayoutDecoration* _replaceAttributeOfLayout( + IRBuilder& builder, + IRLayoutDecoration* parentLayoutDecor, + IRInst* instToReplace, + IRInst* instToReplaceWith) + { + // Replace `instToReplace` with a `instToReplaceWith` -List<SystemValLegalizationWorkItem> LegalizeWGSLEntryPointContext::collectSystemValFromEntryPoint( - EntryPointInfo entryPoint) -{ - List<SystemValLegalizationWorkItem> systemValWorkItems; - for (auto param : entryPoint.entryPointFunc->getParams()) + auto layout = parentLayoutDecor->getLayout(); + // Find the exact same decoration `instToReplace` in-case multiple of the same type exist + List<IRInst*> opList; + opList.add(instToReplaceWith); + for (UInt i = 0; i < layout->getOperandCount(); i++) + { + if (layout->getOperand(i) != instToReplace) + opList.add(layout->getOperand(i)); + } + auto newLayoutDecor = builder.addLayoutDecoration( + parentLayoutDecor->getParent(), + builder.getVarLayout(opList)); + parentLayoutDecor->removeAndDeallocate(); + return newLayoutDecor; + } + + IRLayoutDecoration* _simplifyUserSemanticNames( + IRBuilder& builder, + IRLayoutDecoration* layoutDecor) { - auto maybeWorkItem = makeSystemValWorkItem(param); - if (maybeWorkItem.has_value()) - systemValWorkItems.add(std::move(maybeWorkItem.value())); + // Ensure all 'ExplicitIndex' semantics such as "SV_TARGET0" are simplified into + // ("SV_TARGET", 0) using 'IRUserSemanticAttr' This is done to ensure we can check semantic + // groups using 'IRUserSemanticAttr1->getName() == IRUserSemanticAttr2->getName()' + SLANG_ASSERT(layoutDecor); + auto layout = layoutDecor->getLayout(); + List<IRInst*> layoutOps; + layoutOps.reserve(3); + bool changed = false; + for (auto attr : layout->getAllAttrs()) + { + if (auto userSemantic = as<IRUserSemanticAttr>(attr)) + { + UnownedStringSlice outName; + UnownedStringSlice outIndex; + bool hasStringIndex = splitNameAndIndex(userSemantic->getName(), outName, outIndex); + + changed = true; + auto newDecoration = builder.getUserSemanticAttr( + userSemanticName, + hasStringIndex ? stringToInt(outIndex) : 0); + userSemantic->replaceUsesWith(newDecoration); + userSemantic->removeAndDeallocate(); + userSemantic = newDecoration; + + layoutOps.add(userSemantic); + continue; + } + layoutOps.add(attr); + } + if (changed) + { + auto parent = layoutDecor->parent; + layoutDecor->removeAndDeallocate(); + builder.addLayoutDecoration(parent, builder.getVarLayout(layoutOps)); + } + return layoutDecor; } - return systemValWorkItems; -} -void LegalizeWGSLEntryPointContext::legalizeSystemValue( - EntryPointInfo entryPoint, - SystemValLegalizationWorkItem& workItem) -{ - IRBuilder builder(entryPoint.entryPointFunc); + // Find overlapping field semantics and legalize them + void fixFieldSemanticsOfFlatStruct(IRStructType* structType) + { + // Goal is to ensure we do not have overlapping semantics for the user defined semantics: + // Note that in WGSL, the semantics can be either `builtin` without index or `location` with + // index. + /* + // Assume the following code + struct Fragment + { + float4 p0 : SV_POSITION; + float2 p1 : TEXCOORD0; + float2 p2 : TEXCOORD1; + float3 p3 : COLOR0; + float3 p4 : COLOR1; + }; - auto var = workItem.var; - auto semanticName = workItem.attrName; + // Translates into + struct Fragment + { + float4 p0 : BUILTIN_POSITION; + float2 p1 : LOCATION_0; + float2 p2 : LOCATION_1; + float3 p3 : LOCATION_2; + float3 p4 : LOCATION_3; + }; + */ - auto indexAsString = String(workItem.attrIndex); - auto info = getSystemValueInfo(semanticName, &indexAsString, var); + // For Multi-Render-Target, the semantic index must be translated to `location` with + // the same index. Assume the following code + /* + struct Fragment + { + float4 p0 : SV_TARGET1; + float4 p1 : SV_TARGET0; + }; - if (!info.permittedTypes.getCount()) - return; + // Translates into + struct Fragment + { + float4 p0 : LOCATION_1; + float4 p1 : LOCATION_0; + }; + */ - builder.addTargetSystemValueDecoration(var, info.wgslSystemValueName.getUnownedSlice()); + IRBuilder builder(this->m_module); - bool varTypeIsPermitted = false; - auto varType = var->getFullType(); - for (auto& permittedType : info.permittedTypes) - { - varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; + List<IRSemanticDecoration*> overlappingSemanticsDecor; + Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>> + usedSemanticIndexSemanticDecor; + + List<AttributeParentPair<IRVarOffsetAttr>> overlappingVarOffset; + Dictionary<UInt, std::set<UInt, std::less<UInt>>> usedSemanticIndexVarOffset; + + List<AttributeParentPair<IRUserSemanticAttr>> overlappingUserSemantic; + Dictionary<UnownedStringSlice, std::set<UInt, std::less<UInt>>> + usedSemanticIndexUserSemantic; + + // We store a map from old `IRLayoutDecoration*` to new `IRLayoutDecoration*` since when + // legalizing we may destroy and remake a `IRLayoutDecoration*` + Dictionary<IRLayoutDecoration*, IRLayoutDecoration*> oldLayoutDecorToNew; + + // Collect all "semantic info carrying decorations". Any collected decoration will + // fill up their respective 'Dictionary<SEMANTIC_TYPE, OrderedHashSet<UInt>>' + // to keep track of in-use offsets for a semantic type. + // Example: IRSemanticDecoration with name of "SV_TARGET1". + // * This will have SEMANTIC_TYPE of "sv_target". + // * This will use up index '1' + // + // Now if a second equal semantic "SV_TARGET1" is found, we add this decoration to + // a list of 'overlapping semantic info decorations' so we can legalize this + // 'semantic info decoration' later. + // + // NOTE: this is a flat struct, all members are children of the initial + // IRStructType. + for (auto field : structType->getFields()) + { + auto key = field->getKey(); + if (auto semanticDecoration = key->findDecoration<IRSemanticDecoration>()) + { + auto semanticName = semanticDecoration->getSemanticName(); + + // sv_target is treated as a user-semantic because it should be emitted with + // @location like how the user semantics are emitted. + // For fragment shader, only sv_target will user @location, and for non-fragment + // shaders, sv_target is not valid. + bool isUserSemantic = + (semanticName.startsWithCaseInsensitive(toSlice("sv_target")) || + !semanticName.startsWithCaseInsensitive(toSlice("sv_"))); + + // Ensure names are in a uniform lowercase format so we can bunch together simmilar + // semantics. + UnownedStringSlice outName; + UnownedStringSlice outIndex; + bool hasStringIndex = splitNameAndIndex(semanticName, outName, outIndex); + + // user semantics gets all same semantic-name. + auto loweredName = String(outName).toLower(); + auto loweredNameSlice = + isUserSemantic ? userSemanticName : loweredName.getUnownedSlice(); + auto newDecoration = builder.addSemanticDecoration( + key, + loweredNameSlice, + hasStringIndex ? stringToInt(outIndex) : 0); + semanticDecoration->replaceUsesWith(newDecoration); + semanticDecoration->removeAndDeallocate(); + semanticDecoration = newDecoration; + + auto& semanticUse = + usedSemanticIndexSemanticDecor[semanticDecoration->getSemanticName()]; + if (semanticUse.find(semanticDecoration->getSemanticIndex()) != semanticUse.end()) + overlappingSemanticsDecor.add(semanticDecoration); + else + semanticUse.insert(semanticDecoration->getSemanticIndex()); + } + if (auto layoutDecor = key->findDecoration<IRLayoutDecoration>()) + { + // Ensure names are in a uniform lowercase format so we can bunch together simmilar + // semantics + layoutDecor = _simplifyUserSemanticNames(builder, layoutDecor); + oldLayoutDecorToNew[layoutDecor] = layoutDecor; + auto layout = layoutDecor->getLayout(); + for (auto attr : layout->getAllAttrs()) + { + if (auto offset = as<IRVarOffsetAttr>(attr)) + { + auto& semanticUse = usedSemanticIndexVarOffset[offset->getResourceKind()]; + if (semanticUse.find(offset->getOffset()) != semanticUse.end()) + overlappingVarOffset.add({layoutDecor, offset}); + else + semanticUse.insert(offset->getOffset()); + } + else if (auto userSemantic = as<IRUserSemanticAttr>(attr)) + { + auto& semanticUse = usedSemanticIndexUserSemantic[userSemantic->getName()]; + if (semanticUse.find(userSemantic->getIndex()) != semanticUse.end()) + overlappingUserSemantic.add({layoutDecor, userSemantic}); + else + semanticUse.insert(userSemantic->getIndex()); + } + } + } + } + + // Legalize all overlapping 'semantic info decorations' + for (auto decor : overlappingSemanticsDecor) + { + auto newOffset = _returnNonOverlappingAttributeIndex( + usedSemanticIndexSemanticDecor[decor->getSemanticName()]); + builder.addSemanticDecoration( + decor->getParent(), + decor->getSemanticName(), + (int)newOffset); + decor->removeAndDeallocate(); + } + for (auto& varOffset : overlappingVarOffset) + { + auto newOffset = _returnNonOverlappingAttributeIndex( + usedSemanticIndexVarOffset[varOffset.attr->getResourceKind()]); + auto newVarOffset = builder.getVarOffsetAttr( + varOffset.attr->getResourceKind(), + newOffset, + varOffset.attr->getSpace()); + oldLayoutDecorToNew[varOffset.layoutDecor] = _replaceAttributeOfLayout( + builder, + oldLayoutDecorToNew[varOffset.layoutDecor], + varOffset.attr, + newVarOffset); + } + for (auto& userSemantic : overlappingUserSemantic) + { + auto newOffset = _returnNonOverlappingAttributeIndex( + usedSemanticIndexUserSemantic[userSemantic.attr->getName()]); + auto newUserSemantic = + builder.getUserSemanticAttr(userSemantic.attr->getName(), newOffset); + oldLayoutDecorToNew[userSemantic.layoutDecor] = _replaceAttributeOfLayout( + builder, + oldLayoutDecorToNew[userSemantic.layoutDecor], + userSemantic.attr, + newUserSemantic); + } } - if (!varTypeIsPermitted) + void wrapReturnValueInStruct(EntryPointInfo entryPoint) { - // Note: we do not currently prefer any conversion - // example: - // * allowed types for semantic: `float4`, `uint4`, `int4` - // * user used, `float2` - // * Slang will equally prefer `float4` to `uint4` to `int4`. - // This means the type may lose data if slang selects `uint4` or `int4`. - bool foundAConversion = false; - for (auto permittedType : info.permittedTypes) - { - var->setFullType(permittedType); - builder.setInsertBefore( - entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); - - // get uses before we `tryConvertValue` since this creates a new use - List<IRUse*> uses; - for (auto use = var->firstUse; use; use = use->nextUse) - uses.add(use); - - auto convertedValue = tryConvertValue(builder, var, varType); - if (convertedValue == nullptr) - continue; + // Wrap return value into a struct if it is not already a struct. + // For example, given this entry point: + // ``` + // float4 main() : SV_Target { return float3(1,2,3); } + // ``` + // We are going to transform it into: + // ``` + // struct Output { + // float4 value : SV_Target; + // }; + // Output main() { return {float3(1,2,3)}; } - foundAConversion = true; - copyNameHintAndDebugDecorations(convertedValue, var); + auto func = entryPoint.entryPointFunc; - for (auto use : uses) - builder.replaceOperand(use, convertedValue); + auto returnType = func->getResultType(); + if (as<IRVoidType>(returnType)) + return; + auto entryPointLayoutDecor = func->findDecoration<IRLayoutDecoration>(); + if (!entryPointLayoutDecor) + return; + auto entryPointLayout = as<IREntryPointLayout>(entryPointLayoutDecor->getLayout()); + if (!entryPointLayout) + return; + auto resultLayout = entryPointLayout->getResultLayout(); + + // If return type is already a struct, just make sure every field has a semantic. + if (auto returnStructType = as<IRStructType>(returnType)) + { + IRBuilder builder(func); + MapStructToFlatStruct mapOldFieldToNewField; + // Flatten result struct type to ensure we do not have nested semantics + auto flattenedStruct = maybeFlattenNestedStructs( + builder, + returnStructType, + mapOldFieldToNewField, + semanticInfoToRemove); + if (returnStructType != flattenedStruct) + { + // Replace all return-values with the flattenedStruct we made. + _replaceAllReturnInst( + builder, + func, + flattenedStruct, + [&](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* + { + auto srcStructType = as<IRStructType>(srcVal->getDataType()); + SLANG_ASSERT(srcStructType); + auto dstVal = copyBuilder.emitVar(dstType); + mapOldFieldToNewField.emitCopy<( + int)MapStructToFlatStruct::CopyOptions::StructIntoFlatStruct>( + copyBuilder, + dstVal, + srcVal); + return builder.emitLoad(dstVal); + }); + fixUpFuncType(func, flattenedStruct); + } + // Ensure non-overlapping semantics + fixFieldSemanticsOfFlatStruct(flattenedStruct); + ensureResultStructHasUserSemantic(flattenedStruct, resultLayout); + return; } - if (!foundAConversion) + + IRBuilder builder(func); + builder.setInsertBefore(func); + IRStructType* structType = builder.createStructType(); + auto stageText = getStageText(entryPoint.entryPointDecor->getProfile().getStage()); + builder.addNameHintDecoration( + structType, + (String(stageText) + toSlice("Output")).getUnownedSlice()); + auto key = builder.createStructKey(); + builder.addNameHintDecoration(key, toSlice("output")); + builder.addLayoutDecoration(key, resultLayout); + builder.createStructField(structType, key, returnType); + IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder); + structTypeLayoutBuilder.addField(key, resultLayout); + auto typeLayout = structTypeLayoutBuilder.build(); + IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); + auto varLayout = varLayoutBuilder.build(); + ensureResultStructHasUserSemantic(structType, varLayout); + + _replaceAllReturnInst( + builder, + func, + structType, + [](IRBuilder& copyBuilder, IRStructType* dstType, IRInst* srcVal) -> IRInst* + { return copyBuilder.emitMakeStruct(dstType, 1, &srcVal); }); + + // Assign an appropriate system value semantic for stage output + auto stage = entryPoint.entryPointDecor->getProfile().getStage(); + switch (stage) { - // If we can't convert the value, report an error. - for (auto permittedType : info.permittedTypes) + case Stage::Compute: + case Stage::Fragment: { - StringBuilder typeNameSB; - getTypeNameHint(typeNameSB, permittedType); - m_sink->diagnose( - var->sourceLoc, - Diagnostics::systemValueTypeIncompatible, - semanticName, - typeNameSB.produceString()); + IRInst* operands[] = { + builder.getStringValue(userSemanticName), + builder.getIntValue(builder.getIntType(), 0)}; + builder.addDecoration( + key, + kIROp_SemanticDecoration, + operands, + SLANG_COUNT_OF(operands)); + break; } + case Stage::Vertex: + { + builder.addTargetSystemValueDecoration(key, toSlice("position")); + break; + } + default: + SLANG_ASSERT(false); + return; } - } -} -void LegalizeWGSLEntryPointContext::legalizeSystemValueParameters(EntryPointInfo entryPoint) -{ - List<SystemValLegalizationWorkItem> systemValWorkItems = - collectSystemValFromEntryPoint(entryPoint); + fixUpFuncType(func, structType); + } - for (auto index = 0; index < systemValWorkItems.getCount(); index++) + IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType) { - legalizeSystemValue(entryPoint, systemValWorkItems[index]); + auto fromType = val->getFullType(); + if (auto fromVector = as<IRVectorType>(fromType)) + { + if (auto toVector = as<IRVectorType>(toType)) + { + if (fromVector->getElementCount() != toVector->getElementCount()) + { + fromType = builder.getVectorType( + fromVector->getElementType(), + toVector->getElementCount()); + val = builder.emitVectorReshape(fromType, val); + } + } + else if (as<IRBasicType>(toType)) + { + UInt index = 0; + val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + } + else if (auto fromBasicType = as<IRBasicType>(fromType)) + { + if (fromBasicType->getOp() == kIROp_VoidType) + return nullptr; + if (!as<IRBasicType>(toType)) + return nullptr; + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + else + { + return nullptr; + } + return builder.emitCast(toType, val); } -} -void LegalizeWGSLEntryPointContext::legalizeEntryPointForWGSL(EntryPointInfo entryPoint) -{ - legalizeSystemValueParameters(entryPoint); -} - -void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call) -{ - // WGSL does not allow forming a pointer to a sub part of a composite value. - // For example, if we have - // ``` - // struct S { float x; float y; }; - // void foo(inout float v) { v = 1.0f; } - // void main() { S s; foo(s.x); } - // ``` - // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`. - // And trying to form `&s.x` in WGSL is illegal. - // To work around this, we will create a local variable to hold the sub part of - // the composite value. - // And then pass the local variable to the function. - // After the call, we will write back the local variable to the sub part of the - // composite value. - // - IRBuilder builder(call); - builder.setInsertBefore(call); - struct WritebackPair + struct SystemValLegalizationWorkItem { - IRInst* dest; - IRInst* value; + IRInst* var; + IRType* varType; + String attrName; + UInt attrIndex; }; - ShortList<WritebackPair> pendingWritebacks; - for (UInt i = 0; i < call->getArgCount(); i++) + std::optional<SystemValLegalizationWorkItem> tryToMakeSystemValWorkItem( + IRInst* var, + IRType* varType) { - auto arg = call->getArg(i); - auto ptrType = as<IRPtrTypeBase>(arg->getDataType()); - if (!ptrType) - continue; - switch (arg->getOp()) + if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>()) { - case kIROp_Var: - case kIROp_Param: - continue; - default: - break; + if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + { + return { + {var, + varType, + String(semanticDecoration->getSemanticName()).toLower(), + (UInt)semanticDecoration->getSemanticIndex()}}; + } + } + + auto layoutDecor = var->findDecoration<IRLayoutDecoration>(); + if (!layoutDecor) + return {}; + auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>(); + if (!sysValAttr) + return {}; + auto semanticName = String(sysValAttr->getName()); + auto sysAttrIndex = sysValAttr->getIndex(); + + return {{var, varType, semanticName, sysAttrIndex}}; + } + + List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(EntryPointInfo entryPoint) + { + List<SystemValLegalizationWorkItem> systemValWorkItems; + for (auto param : entryPoint.entryPointFunc->getParams()) + { + if (auto structType = as<IRStructType>(param->getDataType())) + { + for (auto field : structType->getFields()) + { + // Nested struct-s are flattened already by flattenInputParameters(). + SLANG_ASSERT(!as<IRStructType>(field->getFieldType())); + + auto key = field->getKey(); + auto fieldType = field->getFieldType(); + auto maybeWorkItem = tryToMakeSystemValWorkItem(key, fieldType); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + continue; + } + + auto maybeWorkItem = tryToMakeSystemValWorkItem(param, param->getFullType()); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); } + return systemValWorkItems; + } + + void legalizeSystemValue(EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem) + { + IRBuilder builder(entryPoint.entryPointFunc); + + auto var = workItem.var; + auto varType = workItem.varType; + auto semanticName = workItem.attrName; - // Create a local variable to hold the input argument. - auto var = builder.emitVar(ptrType->getValueType(), AddressSpace::Function); + auto indexAsString = String(workItem.attrIndex); + auto info = getSystemValueInfo(semanticName, &indexAsString, var); - // Store the input argument into the local variable. - builder.emitStore(var, builder.emitLoad(arg)); - builder.replaceOperand(call->getArgs() + i, var); - pendingWritebacks.add({arg, var}); + if (info.isUnsupported) + { + reportUnsupportedSystemAttribute(var, semanticName); + return; + } + if (!info.permittedTypes.getCount()) + return; + + builder.addTargetSystemValueDecoration(var, info.wgslSystemValueName.getUnownedSlice()); + + bool varTypeIsPermitted = false; + for (auto& permittedType : info.permittedTypes) + { + varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; + } + + if (!varTypeIsPermitted) + { + // Note: we do not currently prefer any conversion + // example: + // * allowed types for semantic: `float4`, `uint4`, `int4` + // * user used, `float2` + // * Slang will equally prefer `float4` to `uint4` to `int4`. + // This means the type may lose data if slang selects `uint4` or `int4`. + bool foundAConversion = false; + for (auto permittedType : info.permittedTypes) + { + var->setFullType(permittedType); + builder.setInsertBefore( + entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + + // get uses before we `tryConvertValue` since this creates a new use + List<IRUse*> uses; + for (auto use = var->firstUse; use; use = use->nextUse) + uses.add(use); + + auto convertedValue = tryConvertValue(builder, var, varType); + if (convertedValue == nullptr) + continue; + + foundAConversion = true; + copyNameHintAndDebugDecorations(convertedValue, var); + + for (auto use : uses) + builder.replaceOperand(use, convertedValue); + } + if (!foundAConversion) + { + // If we can't convert the value, report an error. + for (auto permittedType : info.permittedTypes) + { + StringBuilder typeNameSB; + getTypeNameHint(typeNameSB, permittedType); + m_sink->diagnose( + var->sourceLoc, + Diagnostics::systemValueTypeIncompatible, + semanticName, + typeNameSB.produceString()); + } + } + } } - // Perform writebacks after the call. - builder.setInsertAfter(call); - for (auto& pair : pendingWritebacks) + void legalizeSystemValueParameters(EntryPointInfo entryPoint) { - builder.emitStore(pair.dest, builder.emitLoad(pair.value)); + List<SystemValLegalizationWorkItem> systemValWorkItems = + collectSystemValFromEntryPoint(entryPoint); + + for (auto index = 0; index < systemValWorkItems.getCount(); index++) + { + legalizeSystemValue(entryPoint, systemValWorkItems[index]); + } + fixUpFuncType(entryPoint.entryPointFunc); } -} -void LegalizeWGSLEntryPointContext::legalizeSwitch(IRSwitch* switchInst) -{ - // WGSL Requires all switch statements to contain a default case. - // If the switch statement does not contain a default case, we will add one. - if (switchInst->getDefaultLabel() != switchInst->getBreakLabel()) - return; - IRBuilder builder(switchInst); - auto defaultBlock = builder.createBlock(); - builder.setInsertInto(defaultBlock); - builder.emitBranch(switchInst->getBreakLabel()); - defaultBlock->insertBefore(switchInst->getBreakLabel()); - List<IRInst*> cases; - for (UInt i = 0; i < switchInst->getCaseCount(); i++) + void legalizeEntryPointForWGSL(EntryPointInfo entryPoint) { - cases.add(switchInst->getCaseValue(i)); - cases.add(switchInst->getCaseLabel(i)); + // Input Parameter Legalize + flattenInputParameters(entryPoint); + + // System Value Legalize + legalizeSystemValueParameters(entryPoint); + + // Output Value Legalize + wrapReturnValueInStruct(entryPoint); } - builder.setInsertBefore(switchInst); - auto newSwitch = builder.emitSwitch( - switchInst->getCondition(), - switchInst->getBreakLabel(), - defaultBlock, - (UInt)cases.getCount(), - cases.getBuffer()); - switchInst->transferDecorationsTo(newSwitch); - switchInst->removeAndDeallocate(); -} -void LegalizeWGSLEntryPointContext::legalizeBinaryOp(IRInst* inst) -{ - auto isVectorOrMatrix = [](IRType* type) + void legalizeCall(IRCall* call) { - switch (type->getOp()) + // WGSL does not allow forming a pointer to a sub part of a composite value. + // For example, if we have + // ``` + // struct S { float x; float y; }; + // void foo(inout float v) { v = 1.0f; } + // void main() { S s; foo(s.x); } + // ``` + // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`. + // And trying to form `&s.x` in WGSL is illegal. + // To work around this, we will create a local variable to hold the sub part of + // the composite value. + // And then pass the local variable to the function. + // After the call, we will write back the local variable to the sub part of the + // composite value. + // + IRBuilder builder(call); + builder.setInsertBefore(call); + struct WritebackPair { - case kIROp_VectorType: - case kIROp_MatrixType: - return true; - default: - return false; + IRInst* dest; + IRInst* value; + }; + ShortList<WritebackPair> pendingWritebacks; + + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + auto ptrType = as<IRPtrTypeBase>(arg->getDataType()); + if (!ptrType) + continue; + switch (arg->getOp()) + { + case kIROp_Var: + case kIROp_Param: + continue; + default: + break; + } + + // Create a local variable to hold the input argument. + auto var = builder.emitVar(ptrType->getValueType(), AddressSpace::Function); + + // Store the input argument into the local variable. + builder.emitStore(var, builder.emitLoad(arg)); + builder.replaceOperand(call->getArgs() + i, var); + pendingWritebacks.add({arg, var}); } - }; - if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && - as<IRBasicType>(inst->getOperand(1)->getDataType())) + + // Perform writebacks after the call. + builder.setInsertAfter(call); + for (auto& pair : pendingWritebacks) + { + builder.emitStore(pair.dest, builder.emitLoad(pair.value)); + } + } + + void legalizeSwitch(IRSwitch* switchInst) { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto newRhs = builder.emitMakeCompositeFromScalar( - inst->getOperand(0)->getDataType(), - inst->getOperand(1)); - builder.replaceOperand(inst->getOperands() + 1, newRhs); + // WGSL Requires all switch statements to contain a default case. + // If the switch statement does not contain a default case, we will add one. + if (switchInst->getDefaultLabel() != switchInst->getBreakLabel()) + return; + IRBuilder builder(switchInst); + auto defaultBlock = builder.createBlock(); + builder.setInsertInto(defaultBlock); + builder.emitBranch(switchInst->getBreakLabel()); + defaultBlock->insertBefore(switchInst->getBreakLabel()); + List<IRInst*> cases; + for (UInt i = 0; i < switchInst->getCaseCount(); i++) + { + cases.add(switchInst->getCaseValue(i)); + cases.add(switchInst->getCaseLabel(i)); + } + builder.setInsertBefore(switchInst); + auto newSwitch = builder.emitSwitch( + switchInst->getCondition(), + switchInst->getBreakLabel(), + defaultBlock, + (UInt)cases.getCount(), + cases.getBuffer()); + switchInst->transferDecorationsTo(newSwitch); + switchInst->removeAndDeallocate(); } - else if ( - as<IRBasicType>(inst->getOperand(0)->getDataType()) && - isVectorOrMatrix(inst->getOperand(1)->getDataType())) + + void legalizeBinaryOp(IRInst* inst) { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto newLhs = builder.emitMakeCompositeFromScalar( - inst->getOperand(1)->getDataType(), - inst->getOperand(0)); - builder.replaceOperand(inst->getOperands(), newLhs); + auto isVectorOrMatrix = [](IRType* type) + { + switch (type->getOp()) + { + case kIROp_VectorType: + case kIROp_MatrixType: + return true; + default: + return false; + } + }; + if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && + as<IRBasicType>(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newRhs = builder.emitMakeCompositeFromScalar( + inst->getOperand(0)->getDataType(), + inst->getOperand(1)); + builder.replaceOperand(inst->getOperands() + 1, newRhs); + } + else if ( + as<IRBasicType>(inst->getOperand(0)->getDataType()) && + isVectorOrMatrix(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newLhs = builder.emitMakeCompositeFromScalar( + inst->getOperand(1)->getDataType(), + inst->getOperand(0)); + builder.replaceOperand(inst->getOperands(), newLhs); + } } -} -void LegalizeWGSLEntryPointContext::processInst(IRInst* inst) -{ - switch (inst->getOp()) + void processInst(IRInst* inst) { - case kIROp_Call: - legalizeCall(static_cast<IRCall*>(inst)); - break; - case kIROp_Switch: - legalizeSwitch(as<IRSwitch>(inst)); - break; - - // For all binary operators, make sure both side of the operator have the same type - // (vector-ness and matrix-ness). - case kIROp_Add: - case kIROp_Sub: - case kIROp_Mul: - case kIROp_Div: - case kIROp_FRem: - case kIROp_IRem: - case kIROp_And: - case kIROp_Or: - case kIROp_BitAnd: - case kIROp_BitOr: - case kIROp_BitXor: - case kIROp_Lsh: - case kIROp_Rsh: - case kIROp_Eql: - case kIROp_Neq: - case kIROp_Greater: - case kIROp_Less: - case kIROp_Geq: - case kIROp_Leq: - legalizeBinaryOp(inst); - break; - - default: - for (auto child : inst->getModifiableChildren()) - processInst(child); + switch (inst->getOp()) + { + case kIROp_Call: + legalizeCall(static_cast<IRCall*>(inst)); + break; + + case kIROp_Switch: + legalizeSwitch(as<IRSwitch>(inst)); + break; + + // For all binary operators, make sure both side of the operator have the same type + // (vector-ness and matrix-ness). + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + legalizeBinaryOp(inst); + break; + + default: + for (auto child : inst->getModifiableChildren()) + { + processInst(child); + } + } } -} +}; + void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) { List<EntryPointInfo> entryPoints; @@ -484,7 +1419,9 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) LegalizeWGSLEntryPointContext context(sink, module); for (auto entryPoint : entryPoints) + { context.legalizeEntryPointForWGSL(entryPoint); + } // Go through every instruction in the module and legalize them as needed. context.processInst(module->getModuleInst()); diff --git a/tests/compute/assoctype-lookup.slang b/tests/compute/assoctype-lookup.slang index 348391e21..8a032528b 100644 --- a/tests/compute/assoctype-lookup.slang +++ b/tests/compute/assoctype-lookup.slang @@ -16,8 +16,8 @@ struct StandardBoneWeightSet : IBoneWeightSet {
struct PackedType
{
- uint boneIds : BONEIDS;
- uint boneWeights : BONEWEIGHTS;
+ uint boneIds;
+ uint boneWeights;
};
PackedType field;
};
@@ -55,4 +55,4 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) int inputVal = int(tid);
int outputVal = test(inputVal);
gOutputBuffer[tid] = outputVal;
-}
\ No newline at end of file +}
diff --git a/tests/wgsl/nested-varying-input.slang b/tests/wgsl/nested-varying-input.slang new file mode 100644 index 000000000..2cdf4f7eb --- /dev/null +++ b/tests/wgsl/nested-varying-input.slang @@ -0,0 +1,53 @@ +//TEST:SIMPLE(filecheck=VERT): -target wgsl -stage vertex -entry vertexMain +//TEST:SIMPLE(filecheck=FRAG): -target wgsl -stage fragment -entry fragmentMain + +// Tests three aspects: +// 1. Flatten the nested struct for the return type of the entry-point +// 2. For fragment shader, SV_TARGET index must be emitted as @location(index) +// 3. For non-fragment shader, the user defined semantics should be emitted as @location(index) + +struct FragmentOutput +{ + //FRAG: @location(1) color1 + float4 color1 : SV_TARGET1; + + //FRAG: @location(0) color0 + float4 color0 : SV_TARGET0; +}; + +struct NestedVertexOutput +{ + float4 color : COLOR0; +}; + +struct VertexOutput +{ + //VERT: @builtin(position) position + //FRAG: @builtin(position) position + float4 position : SV_Position; + + //VERT: @location(0) uv + //FRAG: @location(0) uv + float2 uv : TEXCOORD0; + + //VERT: @location(1) color + //FRAG: @location(1) color + NestedVertexOutput nested; +}; + +VertexOutput vertexMain() +{ + VertexOutput out; + out.position = float4(1.0, 1.0, 1.0, 1.0); + out.uv = float2(0.5, 0.5); + out.nested.color = float4(0.0, 0.0, 0.0, 0.0); + return out; +} + +FragmentOutput fragmentMain(VertexOutput input) +{ + FragmentOutput out; + out.color0 = input.nested.color; + out.color1 = input.nested.color; + return out; +} |
