diff options
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 101 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 327 | ||||
| -rw-r--r-- | tests/metal/stage-in-2.slang | 70 | ||||
| -rw-r--r-- | tests/metal/stage-in.slang | 8 | ||||
| -rw-r--r-- | tests/metal/system-val-conversion.slang | 29 |
9 files changed, 473 insertions, 89 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 4cf6a899d..b011227fd 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -828,8 +828,11 @@ DIAGNOSTIC(55102, Error, invalidTorchKernelParamType, "'$0' is not a valid param DIAGNOSTIC(55200, Error, unsupportedBuiltinType, "'$0' is not a supported builtin type for the target.") DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$0', but the current code generation target does not allow recursion.") +DIAGNOSTIC(55202, Error, systemValueAttributeNotSupported, "system value semantic '$0' is not supported for the current target.") +DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or convertible to type '$1'.") DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'") + DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0") // GLSL Compatibility diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index 96843e286..366c51840 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -170,8 +170,11 @@ void MetalSourceEmitter::emitFuncParamLayoutImpl(IRInst* param) break; } } - if (auto sysSemanticAttr = layout->findSystemValueSemanticAttr()) - _emitSystemSemantic(sysSemanticAttr->getName(), sysSemanticAttr->getIndex()); + if (!maybeEmitSystemSemantic(param)) + { + if (auto sysSemanticAttr = layout->findSystemValueSemanticAttr()) + _emitUserSemantic(sysSemanticAttr->getName(), sysSemanticAttr->getIndex()); + } } void MetalSourceEmitter::emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) @@ -741,84 +744,24 @@ void MetalSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) } } -void MetalSourceEmitter::_emitSystemSemantic(UnownedStringSlice semanticName, IRIntegerValue semanticIndex) +bool MetalSourceEmitter::maybeEmitSystemSemantic(IRInst* inst) { - if (semanticName.caseInsensitiveEquals(toSlice("SV_POSITION"))) - { - m_writer->emit(" [[position]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_VERTEXID"))) - { - m_writer->emit(" [[vertex_id]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_INSTANCEID"))) - { - m_writer->emit(" [[instance_id]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_Target"))) - { - m_writer->emit(" [[color("); - m_writer->emit(semanticIndex); - m_writer->emit(")]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_PRIMITIVEID"))) + if (auto sysSemanticDecor = inst->findDecoration<IRTargetSystemValueDecoration>()) { - m_writer->emit(" [[primitive_id]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_GROUPID"))) - { - // TODO: not supported by metal. - // We need to implement the transformation logic in slang-ir-metal-legalize.cpp - // to convert SV_GroupID to something like METAL_threadgroup_position_in_grid. - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_GROUPINDEX"))) - { - // TODO: not supported by metal. - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_DISPATCHTHREADID"))) - { - m_writer->emit(" [[thread_position_in_grid]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_GROUPTHREADID"))) - { - m_writer->emit(" [[thread_position_in_threadgroup]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_CLIPDISTANCE"))) - { - m_writer->emit(" [[clip_distance]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_RENDERTARGETARRAYINDEX"))) - { - m_writer->emit(" [[render_target_array_index]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_VIEWPORTARRAYINDEX"))) - { - m_writer->emit(" [[viewport_array_index]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_Depth"))) - { - m_writer->emit(" [[depth(any)]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_DepthGreaterEqual"))) - { - m_writer->emit(" [[depth(greater)]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_DepthLessEqual"))) - { - m_writer->emit(" [[depth(less)]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_Coverage"))) - { - m_writer->emit(" [[sample_mask]]"); - } - else if (semanticName.caseInsensitiveEquals(toSlice("SV_StencilRef"))) - { - m_writer->emit(" [[stencil]]"); + m_writer->emit(" [["); + m_writer->emit(sysSemanticDecor->getSemantic()); + m_writer->emit("]]"); + return true; } - else + return false; +} + +void MetalSourceEmitter::_emitUserSemantic(UnownedStringSlice semanticName, IRIntegerValue semanticIndex) +{ + if (!semanticName.startsWithCaseInsensitive(toSlice("SV_"))) { m_writer->emit(" [[user("); - m_writer->emit(semanticName); + m_writer->emit(String(semanticName).toUpper()); if (semanticIndex != 0) { m_writer->emit("_"); @@ -834,6 +777,10 @@ void MetalSourceEmitter::emitSemanticsImpl(IRInst* inst, bool allowOffsets) if (inst->getOp() == kIROp_StructKey) { // Only emit [[attribute(n)]] on struct keys. + + if (maybeEmitSystemSemantic(inst)) + return; + bool hasSemanticFromLayout = false; if (auto varLayout = findVarLayout(inst)) { @@ -851,7 +798,7 @@ void MetalSourceEmitter::emitSemanticsImpl(IRInst* inst, bool allowOffsets) else if (auto semanticAttr = as<IRSemanticAttr>(attr)) { auto semanticName = String(semanticAttr->getName()).toUpper(); - _emitSystemSemantic(semanticAttr->getName(), semanticAttr->getIndex()); + _emitUserSemantic(semanticAttr->getName(), semanticAttr->getIndex()); hasSemanticFromLayout = true; } } @@ -861,7 +808,7 @@ void MetalSourceEmitter::emitSemanticsImpl(IRInst* inst, bool allowOffsets) { if (auto semanticDecor = inst->findDecoration<IRSemanticDecoration>()) { - _emitSystemSemantic(semanticDecor->getSemanticName(), semanticDecor->getSemanticIndex()); + _emitUserSemantic(semanticDecor->getSemanticName(), semanticDecor->getSemanticIndex()); } } } diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index 55ad3d4cb..8b014d604 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -76,7 +76,8 @@ protected: void _emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val); void _emitStageAccessSemantic(IRStageAccessDecoration* decoration, const char* name); - void _emitSystemSemantic(UnownedStringSlice semanticName, IRIntegerValue semanticIndex); + void _emitUserSemantic(UnownedStringSlice semanticName, IRIntegerValue semanticIndex); + bool maybeEmitSystemSemantic(IRInst* inst); }; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 4d39eb978..37e3d2064 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -715,6 +715,8 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) INST_RANGE(TargetSpecificDecoration, TargetDecoration, RequirePreludeDecoration) INST(GLSLOuterArrayDecoration, glslOuterArray, 1, 0) + INST(TargetSystemValueDecoration, TargetSystemValue, 2, 0) + INST(InterpolationModeDecoration, interpolationMode, 1, 0) INST(NameHintDecoration, nameHint, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5670cad47..301604bf3 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -118,6 +118,19 @@ struct IRTargetDecoration : IRTargetSpecificDefinitionDecoration IR_LEAF_ISA(TargetDecoration) }; +struct IRTargetSystemValueDecoration : IRDecoration +{ + enum { kOp = kIROp_TargetSystemValueDecoration }; + IR_LEAF_ISA(TargetSystemValueDecoration) + + IRStringLit* getSemanticOperand() { return cast<IRStringLit>(getOperand(0)); } + + UnownedStringSlice getSemantic() + { + return getSemanticOperand()->getStringSlice(); + } +}; + struct IRTargetIntrinsicDecoration : IRTargetSpecificDefinitionDecoration { enum { kOp = kIROp_TargetIntrinsicDecoration }; @@ -4351,6 +4364,12 @@ public: void addHighLevelDeclDecoration(IRInst* value, Decl* decl); + IRDecoration* addTargetSystemValueDecoration(IRInst* value, UnownedStringSlice sysValName, UInt index = 0) + { + IRInst* operands[] = { getStringValue(sysValName), getIntValue(getIntType(), index)}; + return addDecoration(value, kIROp_TargetSystemValueDecoration, operands, SLANG_COUNT_OF(operands)); + } + // void addLayoutDecoration(IRInst* value, Layout* layout); void addLayoutDecoration(IRInst* value, IRLayout* layout); diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index d4a234515..70f4cbd27 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -50,6 +50,12 @@ namespace Slang auto structType = as<IRStructType>(param->getDataType()); builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); auto varLayout = findVarLayout(param); + + // If `param` already has a semantic, we don't want to hoist its fields out. + if (varLayout->findSystemValueSemanticAttr() != nullptr || + param->findDecoration<IRSemanticDecoration>()) + continue; + IRStructTypeLayout* structTypeLayout = nullptr; if (varLayout) structTypeLayout = as<IRStructTypeLayout>(varLayout->getTypeLayout()); @@ -214,8 +220,168 @@ namespace Slang fixUpFuncType(func); } + struct MetalSystemValueInfo + { + String metalSystemValueName; + IRType* requiredType; + IRType* altRequiredType; + bool isUnsupported; + bool isSpecial; + }; - void ensureResultStructHasUserSemantic(IRStructType* structType, IRVarLayout* varLayout) + MetalSystemValueInfo getSystemValueInfo(IRBuilder& builder, String semanticName, UInt attrIndex) + { + SLANG_UNUSED(attrIndex); + + MetalSystemValueInfo result = {}; + + semanticName = semanticName.toLower(); + + if (semanticName == "sv_position") + { + result.metalSystemValueName = toSlice("position"); + result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 4)); + } + else if (semanticName == "sv_clipdistance") + { + result.isSpecial = true; + } + else if (semanticName == "sv_culldistance") + { + result.isSpecial = true; + } + else if (semanticName == "sv_coverage") + { + result.metalSystemValueName = toSlice("sample_mask"); + result.requiredType = builder.getBasicType(BaseType::UInt); + } + else if (semanticName == "sv_innercoverage") + { + result.isSpecial = true; + + } + else if (semanticName == "sv_depth") + { + result.metalSystemValueName = toSlice("depth(any)"); + result.requiredType = builder.getBasicType(BaseType::Float); + } + else if (semanticName == "sv_depthgreaterequal") + { + result.metalSystemValueName = toSlice("depth(greater)"); + result.requiredType = builder.getBasicType(BaseType::Float); + } + else if (semanticName == "sv_depthlessequal") + { + result.metalSystemValueName = toSlice("depth(less)"); + result.requiredType = builder.getBasicType(BaseType::Float); + } + else if (semanticName == "sv_dispatchthreadid") + { + result.metalSystemValueName = toSlice("thread_position_in_grid"); + result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); + } + else if (semanticName == "sv_domainlocation") + { + result.metalSystemValueName = toSlice("position_in_patch"); + result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 3)); + result.altRequiredType = builder.getVectorType(builder.getBasicType(BaseType::Float), builder.getIntValue(builder.getIntType(), 2)); + } + else if (semanticName == "sv_groupid") + { + result.isSpecial = true; + } + else if (semanticName == "sv_groupindex") + { + result.isSpecial = true; + } + else if (semanticName == "sv_groupthreadid") + { + result.metalSystemValueName = toSlice("thread_position_in_threadgroup"); + result.requiredType = builder.getVectorType(builder.getBasicType(BaseType::UInt), builder.getIntValue(builder.getIntType(), 3)); + } + else if (semanticName == "sv_gsinstanceid") + { + // Metal does not have geometry shader, so this is invalid. + result.isUnsupported = true; + } + else if (semanticName == "sv_instanceid") + { + result.metalSystemValueName = toSlice("instance_id"); + result.requiredType = builder.getBasicType(BaseType::UInt); + } + else if (semanticName == "sv_isfrontface") + { + result.metalSystemValueName = toSlice("front_facing"); + result.requiredType = builder.getBasicType(BaseType::Bool); + } + else if (semanticName == "sv_outputcontrolpointid") + { + // In metal, a hull shader is just a compute shader. + // This needs to be handled separately, by lowering into an ordinary buffer. + } + else if (semanticName == "sv_pointsize") + { + result.metalSystemValueName = toSlice("point_size"); + result.requiredType = builder.getBasicType(BaseType::Float); + } + else if (semanticName == "sv_primitiveid") + { + result.metalSystemValueName = toSlice("patch_id"); + result.requiredType = builder.getBasicType(BaseType::UInt); + result.altRequiredType = builder.getBasicType(BaseType::UInt16); + } + else if (semanticName == "sv_rendertargetarrayindex") + { + result.metalSystemValueName = toSlice("render_target_array_index"); + result.requiredType = builder.getBasicType(BaseType::UInt); + result.altRequiredType = builder.getBasicType(BaseType::UInt16); + } + else if (semanticName == "sv_sampleindex") + { + result.metalSystemValueName = toSlice("sample_id"); + result.requiredType = builder.getBasicType(BaseType::UInt); + } + else if (semanticName == "sv_stencilref") + { + result.metalSystemValueName = toSlice("stencil"); + result.requiredType = builder.getBasicType(BaseType::UInt); + } + else if (semanticName == "sv_tessfactor") + { + // Tessellation factor outputs should be lowered into a write into a normal buffer. + } + else if (semanticName == "sv_vertexid") + { + result.metalSystemValueName = toSlice("vertex_id"); + result.requiredType = builder.getBasicType(BaseType::UInt); + } + else if (semanticName == "sv_viewid") + { + result.isUnsupported = true; + } + else if (semanticName == "sv_viewportarrayindex") + { + result.metalSystemValueName = toSlice("viewport_array_index"); + result.requiredType = builder.getBasicType(BaseType::UInt); + result.altRequiredType = builder.getBasicType(BaseType::UInt16); + } + else if (semanticName == "sv_target") + { + result.metalSystemValueName = (StringBuilder() << "color(" << String(attrIndex) << ")").produceString(); + } + else + { + result.isUnsupported = true; + } + return result; + } + + void reportUnsupportedSystemAttribute(DiagnosticSink* sink, IRInst* param, String semanticName) + { + sink->diagnose(param->sourceLoc, Diagnostics::systemValueAttributeNotSupported, semanticName); + } + + void ensureResultStructHasUserSemantic(DiagnosticSink* sink, IRStructType* structType, IRVarLayout* varLayout) { // Ensure each field in an output struct type has either a system semantic or a user semantic, // so that signature matching can happen correctly. @@ -225,8 +391,21 @@ namespace Slang for (auto field : structType->getFields()) { auto key = field->getKey(); - if (key->findDecoration<IRSemanticDecoration>()) + if (auto semanticDecor = key->findDecoration<IRSemanticDecoration>()) { + if (semanticDecor->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + { + auto sysValInfo = getSystemValueInfo(builder, semanticDecor->getSemanticName(), semanticDecor->getSemanticIndex()); + if (sysValInfo.isUnsupported || sysValInfo.isSpecial) + { + reportUnsupportedSystemAttribute(sink, field, semanticDecor->getSemanticName()); + } + else + { + builder.addTargetSystemValueDecoration(key, sysValInfo.metalSystemValueName.getUnownedSlice()); + semanticDecor->removeAndDeallocate(); + } + } index++; continue; } @@ -245,7 +424,7 @@ namespace Slang } - void wrapReturnValueInStruct(EntryPointInfo entryPoint) + void wrapReturnValueInStruct(DiagnosticSink* sink, EntryPointInfo entryPoint) { // Wrap return value into a struct if it is not already a struct. // For example, given this entry point: @@ -275,7 +454,7 @@ namespace Slang // If return type is already a struct, just make sure every field has a semantic. if (auto returnStructType = as<IRStructType>(returnType)) { - ensureResultStructHasUserSemantic(returnStructType, resultLayout); + ensureResultStructHasUserSemantic(sink, returnStructType, resultLayout); return; } @@ -288,13 +467,14 @@ namespace Slang auto key = builder.createStructKey(); builder.addNameHintDecoration(key, toSlice("output")); builder.addLayoutDecoration(key, resultLayout); + builder.addTargetSystemValueDecoration(key, toSlice("color(0)")); builder.createStructField(structType, key, returnType); IRStructTypeLayout::Builder structTypeLayoutBuilder(&builder); structTypeLayoutBuilder.addField(key, resultLayout); auto typeLayout = structTypeLayoutBuilder.build(); IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); auto varLayout = varLayoutBuilder.build(); - ensureResultStructHasUserSemantic(structType, varLayout); + ensureResultStructHasUserSemantic(sink, structType, varLayout); for (auto block : func->getBlocks()) { @@ -401,13 +581,146 @@ namespace Slang }); } - void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink) + IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType) + { + auto fromType = val->getFullType(); + if (auto fromVector = as<IRVectorType>(fromType)) + { + if (auto toVector = as<IRVectorType>(toType)) + { + if (fromVector->getElementCount() != toVector->getElementCount()) + { + fromType = builder.getVectorType(fromVector->getElementType(), toVector->getElementCount()); + val = builder.emitVectorReshape(fromType, val); + } + } + else if (as<IRBasicType>(toType)) + { + UInt index = 0; + val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + } + else if (auto fromBasicType = as<IRBasicType>(fromType)) + { + if (fromBasicType->getOp() == kIROp_VoidType) + return nullptr; + if (!as<IRBasicType>(toType)) + return nullptr; + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + else + { + return nullptr; + } + return builder.emitCast(toType, val); + } + + void legalizeSystemValueParameters(EntryPointInfo entryPoint, DiagnosticSink* sink) { SLANG_UNUSED(sink); + struct SystemValLegalizationWorkItem + { + IRParam* param; + String attrName; + UInt attrIndex; + }; + List<SystemValLegalizationWorkItem> systemValWorkItems; + List<SystemValLegalizationWorkItem> workList; + + IRBuilder builder(entryPoint.entryPointFunc); + List<IRParam*> params; + + for (auto param : entryPoint.entryPointFunc->getParams()) + { + if (auto semanticDecoration = param->findDecoration<IRSemanticDecoration>()) + { + if (semanticDecoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + { + systemValWorkItems.add({ param, String(semanticDecoration->getSemanticName()).toLower(), (UInt)semanticDecoration->getSemanticIndex() }); + continue; + } + } + + auto layoutDecor = param->findDecoration<IRLayoutDecoration>(); + if (!layoutDecor) + continue; + auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>(); + if (!sysValAttr) + continue; + auto semanticName = String(sysValAttr->getName()); + auto sysAttrIndex = sysValAttr->getIndex(); + systemValWorkItems.add({ param, semanticName, sysAttrIndex }); + } + for (auto workItem : systemValWorkItems) + { + auto param = workItem.param; + auto semanticName = workItem.attrName; + auto sysAttrIndex = workItem.attrIndex; + + auto info = getSystemValueInfo(builder, semanticName, sysAttrIndex); + if (info.isSpecial) + { + if (semanticName == "sv_innercoverage") + { + // Metal does not support conservative rasterization, so this is always false. + auto val = builder.getBoolValue(false); + param->replaceUsesWith(val); + param->removeAndDeallocate(); + } + else + { + // Process special cases after trivial cases. + workList.add(workItem); + } + } + if (info.isUnsupported) + { + reportUnsupportedSystemAttribute(sink, param, semanticName); + continue; + } + if (!info.requiredType) + continue; + + builder.addTargetSystemValueDecoration(param, info.metalSystemValueName.getUnownedSlice()); + + // If the required type is different from the actual type, we need to insert a conversion. + if (info.requiredType != param->getFullType() && info.altRequiredType != param->getFullType()) + { + auto targetType = param->getFullType(); + builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + param->setFullType(info.requiredType); + List<IRUse*> uses; + for (auto use = param->firstUse; use; use = use->nextUse) + uses.add(use); + auto convertedValue = tryConvertValue(builder, param, targetType); + copyNameHintAndDebugDecorations(convertedValue, param); + if (!convertedValue) + { + // If we can't convert the value, report an error. + StringBuilder typeNameSB; + getTypeNameHint(typeNameSB, info.requiredType); + sink->diagnose(param->sourceLoc, Diagnostics::systemValueTypeIncompatible, semanticName, typeNameSB.produceString()); + } + else + { + for (auto use : uses) + builder.replaceOperand(use, convertedValue); + } + } + } + fixUpFuncType(entryPoint.entryPointFunc); + } + + void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink) + { hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); - wrapReturnValueInStruct(entryPoint); + legalizeSystemValueParameters(entryPoint, sink); + wrapReturnValueInStruct(sink, entryPoint); legalizeMeshEntryPoint(entryPoint); legalizeDispatchMeshPayloadForMetal(entryPoint); } diff --git a/tests/metal/stage-in-2.slang b/tests/metal/stage-in-2.slang new file mode 100644 index 000000000..2b1e61306 --- /dev/null +++ b/tests/metal/stage-in-2.slang @@ -0,0 +1,70 @@ +//TEST:SIMPLE(filecheck=CHECK): -target metal +//TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib + +// CHECK-ASM: define {{.*}} @vertexMain +// CHECK-ASM: define {{.*}} @fragmentMain + +// Check that we don't flatten stage-input parameters that have user semantics. + +// CHECK: struct pixelInput +// CHECK-NEXT: { +// CHECK-NEXT: CoarseVertex{{.*}} coarseVertex{{.*}} {{\[\[}}user(COARSEVERTEX){{\]\]}}; + +// Uniform data to be passed from application -> shader. +cbuffer Uniforms +{ + float4x4 modelViewProjection; +} + +// Per-vertex attributes to be assembled from bound vertex buffers. +struct AssembledVertex +{ + float3 position : POSITION; + float3 color : COLOR; +}; + +// Output of the vertex shader, and input to the fragment shader. +struct CoarseVertex +{ + float3 color; +}; + +// Output of the fragment shader +struct Fragment +{ + float4 color; +}; + +// Vertex Shader + +struct VertexStageOutput +{ + CoarseVertex coarseVertex : CoarseVertex; + float4 sv_position : SV_Position; +}; + +[shader("vertex")] +VertexStageOutput vertexMain( + AssembledVertex assembledVertex) +{ + VertexStageOutput output; + + float3 position = assembledVertex.position; + float3 color = assembledVertex.color; + + output.coarseVertex.color = color; + output.sv_position = mul(modelViewProjection, float4(position, 1.0)); + + return output; +} + +// Fragment Shader + +[shader("fragment")] +float4 fragmentMain( + CoarseVertex coarseVertex : CoarseVertex) : SV_Target +{ + float3 color = coarseVertex.color; + + return float4(color, 1.0); +}
\ No newline at end of file diff --git a/tests/metal/stage-in.slang b/tests/metal/stage-in.slang index 31d224072..ee586847e 100644 --- a/tests/metal/stage-in.slang +++ b/tests/metal/stage-in.slang @@ -4,8 +4,8 @@ // CHECK: struct VOut{{.*}} // CHECK-NEXT:{ // CHECK-NEXT: float4 position{{.*}} {{\[\[}}position{{\]\]}}; -// CHECK-NEXT: float4 vertexColor{{.*}} {{\[\[}}user(_slang_attr){{\]\]}}; -// CHECK-NEXT: float2 vertexUV{{.*}} {{\[\[}}user(_slang_attr_1){{\]\]}}; +// CHECK-NEXT: float4 vertexColor{{.*}} {{\[\[}}user(_SLANG_ATTR){{\]\]}}; +// CHECK-NEXT: float2 vertexUV{{.*}} {{\[\[}}user(_SLANG_ATTR_1){{\]\]}}; // CHECK-NEXT: float3 vertexNormal{{.*}} {{\[\[}}user(NORMAL){{\]\]}}; // CHECK-NEXT:}; @@ -24,8 +24,8 @@ // CHECK: struct pixelInput{{.*}} // CHECK-NEXT:{ -// CHECK-NEXT: float4 vertexColor{{.*}} {{\[\[}}user(_slang_attr){{\]\]}}; -// CHECK-NEXT: float2 vertexUV{{.*}} {{\[\[}}user(_slang_attr_1){{\]\]}}; +// CHECK-NEXT: float4 vertexColor{{.*}} {{\[\[}}user(_SLANG_ATTR){{\]\]}}; +// CHECK-NEXT: float2 vertexUV{{.*}} {{\[\[}}user(_SLANG_ATTR_1){{\]\]}}; // CHECK-NEXT: float3 vertexNormal{{.*}} {{\[\[}}user(NORMAL){{\]\]}}; // CHECK-NEXT:}; diff --git a/tests/metal/system-val-conversion.slang b/tests/metal/system-val-conversion.slang new file mode 100644 index 000000000..5a208086d --- /dev/null +++ b/tests/metal/system-val-conversion.slang @@ -0,0 +1,29 @@ +//TEST:SIMPLE(filecheck=CHECK-ASM): -target metallib +//TEST:SIMPLE(filecheck=CHECK): -target metal + +// Test that we always emit correct type for system value and insert conversion logic +// if the declared type of the SV is different from the spec-defined type. + +uniform RWStructuredBuffer<float> outputBuffer; + +RWByteAddressBuffer buffer; + +// CHECK-ASM: define void @main_kernel + +struct TestStruct +{ + uint8_t a; + float16_t h; + float b; + float4 c; + float4x3 d; +} + +// CHECK: void main_kernel(uint3 tid{{.*}} +// CHECK: int tid{{.*}} = int(tid{{.*}}.x); + +[numthreads(1,1,1)] +void main_kernel(int tid: SV_DispatchThreadID) +{ + buffer.Store(128, buffer.Load<TestStruct>(tid)); +} |
