From 4485cf3eaf142cfd5f8470e86739acc67d4e12ea Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 6 Mar 2025 14:26:34 -0800 Subject: Update SPIRV-Tools and fix new validation errors. (#6511) * Update SPIRV-Tools and fix new validation errors. * Implement pointers for glsl target. * Reworked packStorage/unpackStorage code gen to operate on pointers rather than values. --- source/slang/hlsl.meta.slang | 76 +--- source/slang/slang-ast-type.h | 7 - source/slang/slang-emit-c-like.cpp | 7 +- source/slang/slang-emit-glsl.cpp | 116 +++++- source/slang/slang-emit-spirv-ops.h | 14 + source/slang/slang-emit-spirv.cpp | 87 +++-- source/slang/slang-emit.cpp | 5 +- source/slang/slang-ir-inst-defs.h | 7 +- source/slang/slang-ir-insts.h | 26 +- source/slang/slang-ir-layout.cpp | 5 +- .../slang/slang-ir-lower-buffer-element-type.cpp | 431 +++++++++++++-------- source/slang/slang-ir-spirv-legalize.cpp | 135 ++++++- source/slang/slang-ir-util.cpp | 5 - source/slang/slang-ir.cpp | 40 +- source/slang/slang-ir.h | 6 - source/slang/slang-lower-to-ir.cpp | 11 +- 16 files changed, 656 insertions(+), 322 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index d2abfc7fe..c26a7613b 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -21609,90 +21609,36 @@ extension uint64_t } } -__generic -__intrinsic_type($(kIROp_HLSLConstBufferPointerType)) -__glsl_extension(GL_EXT_buffer_reference) -__magic_type(ConstBufferPointerType) -[require(glsl_spirv, bufferreference)] -struct ConstBufferPointer +struct ConstBufferPointer { - __glsl_version(450) - __glsl_extension(GL_EXT_buffer_reference) - [__NoSideEffect] - T get() - { - __target_switch - { - case glsl: - __intrinsic_asm "$0._data"; - case spirv: - return spirv_asm { - result:$$T = OpLoad $this Aligned !Alignment; - }; - } - } - + T *_ptr; + [ForceInline] T get() { return loadAligned(_ptr); } __subscript(int index) -> T { [ForceInline] - get {return ConstBufferPointer.fromUInt(toUInt() + __naturalStrideOf() * index).get(); } + get { return _ptr[index]; } } - __glsl_version(450) - __glsl_extension(GL_EXT_shader_explicit_arithmetic_types_int64) - __glsl_extension(GL_EXT_buffer_reference) - [require(glsl_spirv, bufferreference_int64)] + [ForceInline] T* getPtr() { return _ptr; } + + [ForceInline] static ConstBufferPointer fromUInt(uint64_t val) { - __target_switch - { - case glsl: - __intrinsic_asm "$TR($0)"; - case spirv: - return spirv_asm { - result:$$ConstBufferPointer = OpConvertUToPtr $val; - }; - } + return {(T*)val}; } - __glsl_version(450) - __glsl_extension(GL_EXT_shader_explicit_arithmetic_types_int64) - __glsl_extension(GL_EXT_buffer_reference) - [require(glsl_spirv, bufferreference_int64)] + [ForceInline] uint64_t toUInt() { - __target_switch - { - case glsl: - __intrinsic_asm "uint64_t($0)"; - case spirv: - return spirv_asm { - result:$$uint64_t = OpConvertPtrToU $this; - }; - } + return (uint64_t)_ptr; } - __glsl_version(450) - __glsl_extension(GL_EXT_shader_explicit_arithmetic_types_int64) - __glsl_extension(GL_EXT_buffer_reference) - [__NoSideEffect] [ForceInline] - [require(glsl_spirv, bufferreference_int64)] bool isValid() { - __target_switch - { - case glsl: - __intrinsic_asm "(uint64_t($0) != 0)"; - case spirv: - uint64_t zero = 0ULL; - return spirv_asm { - %ptrval:$$uint64_t = OpConvertPtrToU $this; - result:$$bool = OpINotEqual %ptrval $zero; - }; - } + return _ptr != nullptr; } } diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index f60c0485e..7393092f9 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -632,13 +632,6 @@ class PtrType : public PtrTypeBase void _toTextOverride(StringBuilder& out); }; -// A GPU pointer type into global memory. - -class ConstBufferPointerType : public PtrTypeBase -{ - SLANG_AST_CLASS(ConstBufferPointerType) -}; - /// A pointer-like type used to represent a parameter "direction" class ParamDirectionType : public PtrTypeBase { diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index ff40d5b28..d337e09a0 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1236,9 +1236,9 @@ String CLikeSourceEmitter::generateName(IRInst* inst) return linkageDecoration->getMangledName(); } - switch (inst->getOp()) + if (auto ptrType = as(inst)) { - case kIROp_HLSLConstBufferPointerType: + if (ptrType->getAddressSpace() == AddressSpace::UserPointer) { StringBuilder sb; sb << "BufferPointer_"; @@ -1246,9 +1246,8 @@ String CLikeSourceEmitter::generateName(IRInst* inst) sb << "_" << Int32(getID(inst)); return sb.produceString(); } - default: - break; } + // Otherwise fall back to a construct temporary name // for the instruction. StringBuilder sb; diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index fca5a8933..86699df48 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -2036,13 +2036,22 @@ bool GLSLSourceEmitter::_tryEmitBitBinOp( return true; } -void GLSLSourceEmitter::emitBufferPointerTypeDefinition(IRInst* ptrType) +void GLSLSourceEmitter::emitBufferPointerTypeDefinition(IRInst* type) { + auto ptrType = as(type); + if (!ptrType) + return; + if (ptrType->getAddressSpace() != AddressSpace::UserPointer) + return; _requireGLSLExtension(UnownedStringSlice("GL_EXT_buffer_reference")); - auto constPtrType = as(ptrType); auto ptrTypeName = getName(ptrType); - auto alignment = getIntVal(constPtrType->getBaseAlignment()); + IRSizeAndAlignment sizeAlignment; + getNaturalSizeAndAlignment( + m_codeGenContext->getTargetProgram()->getOptionSet(), + ptrType->getValueType(), + &sizeAlignment); + auto alignment = sizeAlignment.alignment; m_writer->emit("layout(buffer_reference, std430, buffer_reference_align = "); m_writer->emitInt64(alignment); m_writer->emit(") readonly buffer "); @@ -2050,7 +2059,7 @@ void GLSLSourceEmitter::emitBufferPointerTypeDefinition(IRInst* ptrType) m_writer->emit("\n"); m_writer->emit("{\n"); m_writer->indent(); - emitType((IRType*)constPtrType->getValueType(), "_data"); + emitType((IRType*)ptrType->getValueType(), "_data"); m_writer->emit(";\n"); m_writer->dedent(); m_writer->emit("};\n"); @@ -2079,7 +2088,7 @@ void GLSLSourceEmitter::emitGlobalInstImpl(IRInst* inst) { switch (inst->getOp()) { - case kIROp_HLSLConstBufferPointerType: + case kIROp_PtrType: emitBufferPointerTypeDefinition(inst); break; // No need to use structs which are just taking part in a SSBO declaration @@ -2102,6 +2111,43 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit("barrier();\n"); return true; } + case kIROp_Load: + { + auto addr = inst->getOperand(0); + auto ptrType = as(addr->getDataType()); + if (!ptrType) + return false; + if (ptrType->getAddressSpace() == AddressSpace::UserPointer) + { + auto prec = getInfo(EmitOp::Postfix); + EmitOpInfo outerPrec = inOuterPrec; + bool needClose = maybeEmitParens(outerPrec, prec); + emitOperand(inst->getOperand(0), prec); + m_writer->emit("._data"); + maybeCloseParens(needClose); + return true; + } + return false; + } + case kIROp_FieldAddress: + { + auto addr = inst->getOperand(0); + auto ptrType = as(addr->getDataType()); + if (!ptrType) + return false; + if (ptrType->getAddressSpace() == AddressSpace::UserPointer) + { + auto prec = getInfo(EmitOp::Postfix); + EmitOpInfo outerPrec = inOuterPrec; + bool needClose = maybeEmitParens(outerPrec, prec); + emitOperand(inst->getOperand(0), prec); + m_writer->emit("._data."); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + maybeCloseParens(needClose); + return true; + } + return false; + } case kIROp_MakeVectorFromScalar: case kIROp_MatrixReshape: { @@ -2372,10 +2418,57 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu return true; } + if (as(left->getDataType()) || as(right->getDataType())) + { + _requireGLSLExtension( + UnownedStringSlice("GL_EXT_shader_explicit_arithmetic_types_int64")); + // For pointers we need to cast to uint before comparing + auto getOperatorString = [](IROp op) -> const char* + { + switch (op) + { + case kIROp_Eql: + return "=="; + case kIROp_Neq: + return "!="; + case kIROp_Greater: + return ">"; + case kIROp_Less: + return "<"; + case kIROp_Geq: + return ">="; + case kIROp_Leq: + return "<="; + default: + return nullptr; + } + }; + EmitOpInfo outerPrec = inOuterPrec; + auto prec = getInfo(EmitOp::General); + bool needClose = maybeEmitParens(outerPrec, prec); + + m_writer->emit("uint64_t("); + emitOperand(left, getInfo(EmitOp::General)); + m_writer->emit(")"); + m_writer->emit(" "); + m_writer->emit(getOperatorString(inst->getOp())); + m_writer->emit(" "); + m_writer->emit("uint64_t("); + emitOperand(right, getInfo(EmitOp::General)); + m_writer->emit(")"); + + maybeCloseParens(needClose); + return true; + } // Use the default break; } + case kIROp_GetOffsetPtr: + { + _requireGLSLExtension(UnownedStringSlice("GL_EXT_buffer_reference2")); + return false; + } case kIROp_FRem: { IRInst* left = inst->getOperand(0); @@ -2560,6 +2653,16 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(")"); return true; } + case kIROp_PtrLit: + { + auto ptrType = as(inst->getDataType()); + if (ptrType) + { + m_writer->emit("0"); + return true; + } + break; + } default: break; } @@ -3204,10 +3307,9 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) return; } case kIROp_StructType: - case kIROp_HLSLConstBufferPointerType: + case kIROp_PtrType: m_writer->emit(getName(type)); return; - case kIROp_VectorType: { auto vecType = (IRVectorType*)type; diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 8c5316f51..880f6b083 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -1409,6 +1409,20 @@ SpvInst* emitOpBitcast( return emitInst(parent, inst, SpvOpBitcast, idResultType, kResultID, operand); } +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical +template +SpvInst* emitOpCopyLogical( + SpvInstParent* parent, + IRInst* inst, + const T1& idResultType, + const T2& operand) +{ + static_assert(isSingular); + static_assert(isSingular); + return emitInst(parent, inst, SpvOpCopyLogical, idResultType, kResultID, operand); +} + + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSNegate template SpvInst* emitOpSNegate( diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index c7e222247..7a6d4aa38 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -491,7 +491,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex /// The next destination `` to allocate. SpvWord m_nextID = 1; - OrderedHashSet m_forwardDeclaredPointers; + OrderedDictionary m_forwardDeclaredPointers; SpvInst* m_nullDwarfExpr = nullptr; @@ -1437,6 +1437,20 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex m_addressingMode = SpvAddressingModelPhysicalStorageBuffer64; } + bool shouldEmitArrayStride(IRInst* elementType) + { + for (auto decor : elementType->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_SPIRVBufferBlockDecoration: + case kIROp_SPIRVBlockDecoration: + return false; + } + } + return true; + } + // Next, let's look at emitting some of the instructions // that can occur at global scope. @@ -1554,7 +1568,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { // After everything has been emitted, we will move the pointer definition to the // end of the Types & Constants section. - if (m_forwardDeclaredPointers.add(ptrType)) + if (m_forwardDeclaredPointers.addIfNotExists( + resultSpvType, + (IRPtrTypeBase*)inst)) emitOpTypeForwardPointer(resultSpvType, storageClass); } if (storageClass == SpvStorageClassPhysicalStorageBuffer) @@ -1654,42 +1670,22 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_ArrayType: case kIROp_UnsizedArrayType: { - const auto elementType = static_cast(inst)->getElementType(); + auto irArrayType = static_cast(inst); + const auto elementType = irArrayType->getElementType(); const auto arrayType = inst->getOp() == kIROp_ArrayType - ? emitOpTypeArray( - inst, - elementType, - static_cast(inst)->getElementCount()) + ? emitOpTypeArray(inst, elementType, irArrayType->getElementCount()) : emitOpTypeRuntimeArray(inst, elementType); - auto strideInst = as(inst)->getArrayStride(); - int stride = 0; - if (strideInst) - { - stride = (int)getIntVal(strideInst); - } - else - { - IRSizeAndAlignment sizeAndAlignment; - getNaturalSizeAndAlignment( - m_targetProgram->getOptionSet(), - elementType, - &sizeAndAlignment); - stride = (int)sizeAndAlignment.getStride(); - } - - // Avoid validation error: Array containing a Block or BufferBlock must not be - // decorated with ArrayStride - if (!elementType->findDecorationImpl(kIROp_SPIRVBufferBlockDecoration) && - !elementType->findDecorationImpl(kIROp_SPIRVBlockDecoration)) + auto strideInst = irArrayType->getArrayStride(); + if (strideInst && shouldEmitArrayStride(irArrayType->getElementType())) { + int stride = (int)getIntVal(strideInst); emitOpDecorateArrayStride( getSection(SpvLogicalSectionID::Annotations), nullptr, arrayType, SpvLiteralInteger::from32(stride)); } - return arrayType; } case kIROp_AtomicType: @@ -1727,13 +1723,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex requireSPIRVCapability(SpvCapabilityShaderInvocationReorderNV); return emitOpTypeHitObject(inst); - case kIROp_HLSLConstBufferPointerType: - requirePhysicalStorageAddressing(); - return emitOpTypePointer( - inst, - SpvStorageClassPhysicalStorageBuffer, - inst->getOperand(0)); - case kIROp_FuncType: // > OpTypeFunction // @@ -4946,6 +4935,21 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } + bool isPhysicalCompositeType(IRType* type) + { + for (auto decor : type->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_PhysicalTypeDecoration: + case kIROp_SPIRVBlockDecoration: + case kIROp_SPIRVBufferBlockDecoration: + return true; + } + } + return false; + } + void emitLayoutDecorations(IRStructType* structType, SpvWord spvStructID) { /***** @@ -4974,6 +4978,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex layoutRuleName = layout->getLayoutName(); } int32_t id = 0; + bool isPhysicalType = isPhysicalCompositeType(structType); for (auto field : structType->getFields()) { for (auto decor : field->getKey()->getDecorations()) @@ -5054,6 +5059,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } + if (!isPhysicalType) + continue; + + // Emit explicit struct field layout decorations if the struct is physical. IRIntegerValue offset = 0; if (auto offsetDecor = field->getKey()->findDecoration()) { @@ -5084,8 +5093,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex if (matrixType) { // SPIRV sepc on MatrixStride: - // Applies only to a member of a structure type.Only valid on a - // matrix or array whose most basic element is a matrix.Matrix + // Applies only to a member of a structure type. Only valid on a + // matrix or array whose most basic element is a matrix. Matrix // Stride is an unsigned 32 - bit integer specifying the stride // of the rows in a RowMajor - decorated matrix or columns in a // ColMajor - decorated matrix. @@ -8373,11 +8382,11 @@ SlangResult emitSPIRVFromIR( for (auto ptrType : fwdPointers) { - auto spvPtrType = context.m_mapIRInstToSpvInst[ptrType]; + auto spvPtrType = ptrType.key; // When we emit a pointee type, we may introduce new // forward-declared pointer types, so we need to // keep iterating until we have emitted all of them. - context.ensureInst(ptrType->getValueType()); + context.ensureInst(ptrType.value->getValueType()); auto parent = spvPtrType->parent; spvPtrType->removeFromParent(); parent->addInst(spvPtrType); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 94ea66d71..4fb33ccc2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1644,10 +1644,11 @@ Result linkAndOptimizeIR( bufferElementTypeLoweringOptions.use16ByteArrayElementForConstantBuffer = isWGPUTarget(targetRequest); lowerBufferElementTypeToStorageType(targetProgram, irModule, bufferElementTypeLoweringOptions); + performForceInlining(irModule); // Rewrite functions that return arrays to return them via `out` parameter, // since our target languages doesn't allow returning arrays. - if (!isMetalTarget(targetRequest)) + if (!isMetalTarget(targetRequest) && !isSPIRV(target)) legalizeArrayReturnType(irModule); if (isKhronosTarget(targetRequest) || target == CodeGenTarget::HLSL) @@ -1669,8 +1670,8 @@ Result linkAndOptimizeIR( if (emitSpirvDirectly) { performIntrinsicFunctionInlining(irModule); - eliminateDeadCode(irModule, deadCodeEliminationOptions); } + eliminateMultiLevelBreak(irModule); if (!fastIRSimplificationOptions.minimalOptimization) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 714ba146d..d0c0b4b31 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -126,9 +126,9 @@ INST(Nop, nop, 0, 0) INST(OutType, Out, 1, HOISTABLE) INST(InOutType, InOut, 1, HOISTABLE) INST_RANGE(OutTypeBase, OutType, InOutType) - INST(HLSLConstBufferPointerType, ConstBufferPointerType, 2, HOISTABLE) INST_RANGE(PtrTypeBase, PtrType, InOutType) + // A ComPtr type is treated as a opaque type that represents a reference-counted handle to a COM object. INST(ComPtrType, ComPtr, 1, HOISTABLE) // A NativePtr type represents a native pointer to a managed resource. @@ -812,6 +812,11 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) INST(InterpolationModeDecoration, interpolationMode, 1, 0) INST(NameHintDecoration, nameHint, 1, 0) + INST(PhysicalTypeDecoration, PhysicalType, 1, 0) + + // Mark an address instruction as aligned to a specific byte boundary. + INST(AlignedAddressDecoration, AlignedAddressDecoration, 1, 0) + // Marks a type as being used as binary interface (e.g. shader parameters). // This prevents the legalizeEmptyType() pass from eliminating it on C++/CUDA targets. INST(BinaryInterfaceTypeDecoration, BinaryInterfaceType, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5231592ca..a8a96f230 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -168,6 +168,12 @@ struct IRIntrinsicOpDecoration : IRDecoration IROp getIntrinsicOp() { return (IROp)getIntrinsicOpOperand()->getValue(); } }; +struct IRAlignedAddressDecoration : IRDecoration +{ + IR_LEAF_ISA(AlignedAddressDecoration) + IRInst* getAlignment() { return getOperand(0); } +}; + struct IRGLSLOuterArrayDecoration : IRDecoration { enum @@ -879,6 +885,8 @@ IR_SIMPLE_DECORATION(ForceInlineDecoration) IR_SIMPLE_DECORATION(ForceUnrollDecoration) +IR_SIMPLE_DECORATION(PhysicalTypeDecoration) + struct IRSizeAndAlignmentDecoration : IRDecoration { IR_LEAF_ISA(SizeAndAlignmentDecoration) @@ -3788,7 +3796,11 @@ public: /// Get a 'SPIRV literal' IRSPIRVLiteralType* getSPIRVLiteralType(IRType* type); - IRArrayTypeBase* getArrayTypeBase(IROp op, IRType* elementType, IRInst* elementCount); + IRArrayTypeBase* getArrayTypeBase( + IROp op, + IRType* elementType, + IRInst* elementCount, + IRInst* stride = nullptr); IRArrayType* getArrayType(IRType* elementType, IRInst* elementCount); @@ -4324,7 +4336,7 @@ public: IRVar* emitVar(IRType* type, AddressSpace addressSpace); IRInst* emitLoad(IRType* type, IRInst* ptr); - + IRInst* emitLoad(IRType* type, IRInst* ptr, IRInst* align); IRInst* emitLoad(IRInst* ptr); IRInst* emitLoadReverseGradient(IRType* type, IRInst* diffValue); @@ -4333,6 +4345,7 @@ public: IRInst* emitDiffParamRef(IRType* type, IRInst* param); IRInst* emitStore(IRInst* dstPtr, IRInst* srcVal); + IRInst* emitStore(IRInst* dstPtr, IRInst* srcVal, IRInst* align); IRInst* emitAtomicStore(IRInst* dstPtr, IRInst* srcVal, IRInst* memoryOrder); @@ -4725,6 +4738,15 @@ public: IRVarLayout* getVarLayout(List const& operands); IREntryPointLayout* getEntryPointLayout(IRVarLayout* paramsLayout, IRVarLayout* resultLayout); + void addPhysicalTypeDecoration(IRInst* value) + { + addDecoration(value, kIROp_PhysicalTypeDecoration); + } + + void addAlignedAddressDecoration(IRInst* value, IRInst* alignment) + { + addDecoration(value, kIROp_AlignedAddressDecoration, alignment); + } void addNameHintDecoration(IRInst* value, IRStringLit* name) { diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index 558877aaf..c60b39248 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -186,8 +186,8 @@ static Result _calcSizeAndAlignment( builder.getIntValue(intType, (IRIntegerValue)rules->ruleName), builder.getIntValue(intType, fieldOffset)); } - - structLayout.size += fieldTypeLayout.size; + if (!seenFinalUnsizedArrayField) + structLayout.size += fieldTypeLayout.size; offset = structLayout.size; if (as(field->getFieldType()) || as(field->getFieldType()) || @@ -340,7 +340,6 @@ static Result _calcSizeAndAlignment( case kIROp_NativePtrType: case kIROp_ComPtrType: case kIROp_NativeStringType: - case kIROp_HLSLConstBufferPointerType: case kIROp_RaytracingAccelerationStructureType: { *outSizeAndAlignment = IRSizeAndAlignment(kPointerSize, kPointerSize); diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 74e84f1ee..6f0e22a57 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -58,14 +58,39 @@ struct LoweredElementTypeContext this->op = irop; return *this; } - IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operand) + IRInst* apply(IRBuilder& builder, IRType* resultType, IRInst* operandAddr) { if (!*this) - return operand; + return builder.emitLoad(operandAddr); if (kind == ConversionMethodKind::Func) - return builder.emitCallInst(resultType, func, 1, &operand); + return builder.emitCallInst(resultType, func, 1, &operandAddr); else - return builder.emitIntrinsicInst(resultType, op, 1, &operand); + { + auto val = builder.emitLoad(operandAddr); + return builder.emitIntrinsicInst(resultType, op, 1, &val); + } + } + void applyDestinationDriven(IRBuilder& builder, IRInst* dest, IRInst* operand) + { + if (!*this) + { + builder.emitStore(dest, operand); + return; + } + if (kind == ConversionMethodKind::Func) + { + IRInst* operands[] = {dest, operand}; + builder.emitCallInst(builder.getVoidType(), func, 2, operands); + } + else + { + auto val = builder.emitIntrinsicInst( + tryGetPointedToType(&builder, dest->getDataType()), + op, + 1, + &operand); + builder.emitStore(dest, val); + } } }; @@ -131,21 +156,23 @@ struct LoweredElementTypeContext IRFunc* createMatrixUnpackFunc( IRMatrixType* matrixType, IRStructType* structType, - IRStructKey* dataKey, - IRArrayType* arrayType) + IRStructKey* dataKey) { IRBuilder builder(structType); builder.setInsertAfter(structType); auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType); + auto refStructType = builder.getRefType(structType, AddressSpace::Generic); + auto funcType = builder.getFuncType(1, (IRType**)&refStructType, matrixType); func->setFullType(funcType); builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.addForceInlineDecoration(func); builder.setInsertInto(func); builder.emitBlock(); auto rowCount = (Index)getIntVal(matrixType->getRowCount()); auto colCount = (Index)getIntVal(matrixType->getColumnCount()); - auto packedParam = builder.emitParam(structType); - auto vectorArray = builder.emitFieldExtract(arrayType, packedParam, dataKey); + auto packedParamRef = builder.emitParam(refStructType); + auto packedParam = builder.emitLoad(packedParamRef); + auto vectorArray = builder.emitFieldExtract(packedParam, dataKey); List args; args.setCount(rowCount * colCount); if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) @@ -187,13 +214,17 @@ struct LoweredElementTypeContext IRBuilder builder(structType); builder.setInsertAfter(structType); auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType); + auto outStructType = builder.getRefType(structType, AddressSpace::Generic); + IRType* paramTypes[] = {outStructType, matrixType}; + auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); func->setFullType(funcType); builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); + builder.addForceInlineDecoration(func); builder.setInsertInto(func); builder.emitBlock(); auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); + auto outParam = builder.emitParam(outStructType); auto originalParam = builder.emitParam(matrixType); List elements; elements.setCount((Index)(rowCount * colCount)); @@ -255,7 +286,8 @@ struct LoweredElementTypeContext auto vectorArray = builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); auto result = builder.emitMakeStruct(structType, 1, &vectorArray); - builder.emitReturn(result); + builder.emitStore(outParam, result); + builder.emitReturn(); return func; } @@ -263,19 +295,20 @@ struct LoweredElementTypeContext IRArrayType* arrayType, IRStructType* structType, IRStructKey* dataKey, - IRArrayType* innerArrayType, LoweredElementTypeInfo innerTypeInfo) { IRBuilder builder(structType); builder.setInsertAfter(structType); auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&structType, arrayType); + auto refStructType = builder.getRefType(structType, AddressSpace::Generic); + auto funcType = builder.getFuncType(1, (IRType**)&refStructType, arrayType); func->setFullType(funcType); builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.addForceInlineDecoration(func); builder.setInsertInto(func); builder.emitBlock(); - auto packedParam = builder.emitParam(structType); - auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey); + auto packedParam = builder.emitParam(refStructType); + auto packedArray = builder.emitFieldAddress(packedParam, dataKey); auto count = getIntVal(arrayType->getElementCount()); IRInst* result = nullptr; if (count <= kMaxArraySizeToUnroll) @@ -285,11 +318,11 @@ struct LoweredElementTypeContext args.setCount((Index)count); for (IRIntegerValue ii = 0; ii < count; ++ii) { - auto packedElement = builder.emitElementExtract(packedArray, ii); + auto packedElementAddr = builder.emitElementAddress(packedArray, ii); auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply( builder, innerTypeInfo.originalType, - packedElement); + packedElementAddr); args[(Index)ii] = originalElement; } result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); @@ -308,11 +341,11 @@ struct LoweredElementTypeContext loopBreakBlock); builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); - auto packedElement = builder.emitElementExtract(packedArray, loopParam); + auto packedElementAddr = builder.emitElementAddress(packedArray, loopParam); auto originalElement = innerTypeInfo.convertLoweredToOriginal.apply( builder, innerTypeInfo.originalType, - packedElement); + packedElementAddr); auto varPtr = builder.emitElementAddress(resultVar, loopParam); builder.emitStore(varPtr, originalElement); builder.setInsertInto(loopBreakBlock); @@ -325,20 +358,24 @@ struct LoweredElementTypeContext IRFunc* createArrayPackFunc( IRArrayType* arrayType, IRStructType* structType, - IRArrayType* innerArrayType, + IRStructKey* arrayStructKey, LoweredElementTypeInfo innerTypeInfo) { IRBuilder builder(structType); builder.setInsertAfter(structType); auto func = builder.createFunc(); - auto funcType = builder.getFuncType(1, (IRType**)&arrayType, structType); + auto outLoweredType = builder.getRefType(structType, AddressSpace::Generic); + IRType* paramTypes[] = {outLoweredType, structType}; + auto funcType = builder.getFuncType(2, paramTypes, builder.getVoidType()); func->setFullType(funcType); builder.addNameHintDecoration(func, UnownedStringSlice("packStorage")); + builder.addForceInlineDecoration(func); builder.setInsertInto(func); builder.emitBlock(); + auto outParam = builder.emitParam(outLoweredType); auto originalParam = builder.emitParam(arrayType); - IRInst* packedArray = nullptr; auto count = getIntVal(arrayType->getElementCount()); + auto destArray = builder.emitFieldAddress(outParam, arrayStructKey); if (count <= kMaxArraySizeToUnroll) { // If the array is small enough, just process each element directly. @@ -347,19 +384,16 @@ struct LoweredElementTypeContext for (IRIntegerValue ii = 0; ii < count; ++ii) { auto originalElement = builder.emitElementExtract(originalParam, ii); - auto packedElement = innerTypeInfo.convertOriginalToLowered.apply( + auto destArrayElement = builder.emitElementAddress(destArray, ii); + innerTypeInfo.convertOriginalToLowered.applyDestinationDriven( builder, - innerTypeInfo.loweredType, + destArrayElement, originalElement); - args[(Index)ii] = packedElement; } - packedArray = - builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer()); } else { // The general case for large arrays is to emit a loop through the elements. - IRVar* packedArrayVar = builder.emitVar(innerArrayType); IRBlock* loopBodyBlock; IRBlock* loopBreakBlock; auto loopParam = emitLoopBlocks( @@ -371,18 +405,14 @@ struct LoweredElementTypeContext builder.setInsertBefore(loopBodyBlock->getFirstOrdinaryInst()); auto originalElement = builder.emitElementExtract(originalParam, loopParam); - auto packedElement = innerTypeInfo.convertOriginalToLowered.apply( + auto varPtr = builder.emitElementAddress(destArray, loopParam); + innerTypeInfo.convertOriginalToLowered.applyDestinationDriven( builder, - innerTypeInfo.loweredType, + varPtr, originalElement); - auto varPtr = builder.emitElementAddress(packedArrayVar, loopParam); - builder.emitStore(varPtr, packedElement); builder.setInsertInto(loopBreakBlock); - packedArray = builder.emitLoad(packedArrayVar); } - - auto result = builder.emitMakeStruct(structType, 1, &packedArray); - builder.emitReturn(result); + builder.emitReturn(); return func; } @@ -451,6 +481,8 @@ struct LoweredElementTypeContext } auto loweredType = builder.createStructType(); + builder.addPhysicalTypeDecoration(loweredType); + StringBuilder nameSB; bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; @@ -494,14 +526,14 @@ struct LoweredElementTypeContext info.loweredInnerArrayType = arrayType; info.loweredInnerStructKey = structKey; info.convertLoweredToOriginal = - createMatrixUnpackFunc(matrixType, loweredType, structKey, arrayType); + createMatrixUnpackFunc(matrixType, loweredType, structKey); info.convertOriginalToLowered = createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); return info; } - else if (auto arrayType = as(type)) + else if (auto arrayTypeBase = as(type)) { - auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), config); + auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayTypeBase->getElementType(), config); if (config.layoutRule->ruleName == IRTypeLayoutRuleName::Std140 && options.use16ByteArrayElementForConstantBuffer) @@ -560,42 +592,59 @@ struct LoweredElementTypeContext } } - auto loweredType = builder.createStructType(); - info.loweredType = loweredType; - StringBuilder nameSB; - nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_"; - getTypeNameHint(nameSB, arrayType->getElementType()); - nameSB << getIntVal(arrayType->getElementCount()); - builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); - auto structKey = builder.createStructKey(); - builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); - IRSizeAndAlignment elementSizeAlignment; - getSizeAndAlignment( - target->getOptionSet(), - config.layoutRule, - loweredInnerTypeInfo.loweredType, - &elementSizeAlignment); - elementSizeAlignment = config.layoutRule->alignCompositeElement(elementSizeAlignment); - auto innerArrayType = builder.getArrayType( - loweredInnerTypeInfo.loweredType, - arrayType->getElementCount(), - builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); - builder.createStructField(loweredType, structKey, innerArrayType); - info.loweredInnerArrayType = innerArrayType; - info.loweredInnerStructKey = structKey; - info.convertLoweredToOriginal = createArrayUnpackFunc( - arrayType, - loweredType, - structKey, - innerArrayType, - loweredInnerTypeInfo); - info.convertOriginalToLowered = - createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo); - return info; - } - else if (as(type)) - { - info.loweredType = builder.getVoidType(); + auto arrayType = as(arrayTypeBase); + if (arrayType) + { + auto loweredType = builder.createStructType(); + builder.addPhysicalTypeDecoration(loweredType); + + info.loweredType = loweredType; + StringBuilder nameSB; + nameSB << "_Array_" << getLayoutName(config.layoutRule->ruleName) << "_"; + getTypeNameHint(nameSB, arrayType->getElementType()); + nameSB << getIntVal(arrayType->getElementCount()); + builder.addNameHintDecoration( + loweredType, + nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + loweredInnerTypeInfo.loweredType, + &elementSizeAlignment); + elementSizeAlignment = + config.layoutRule->alignCompositeElement(elementSizeAlignment); + auto innerArrayType = builder.getArrayType( + loweredInnerTypeInfo.loweredType, + arrayType->getElementCount(), + builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); + builder.createStructField(loweredType, structKey, innerArrayType); + info.loweredInnerArrayType = innerArrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = + createArrayUnpackFunc(arrayType, loweredType, structKey, loweredInnerTypeInfo); + info.convertOriginalToLowered = + createArrayPackFunc(arrayType, loweredType, structKey, loweredInnerTypeInfo); + } + else + { + IRSizeAndAlignment elementSizeAlignment; + getSizeAndAlignment( + target->getOptionSet(), + config.layoutRule, + loweredInnerTypeInfo.loweredType, + &elementSizeAlignment); + elementSizeAlignment = + config.layoutRule->alignCompositeElement(elementSizeAlignment); + auto innerArrayType = builder.getArrayTypeBase( + arrayTypeBase->getOp(), + loweredInnerTypeInfo.loweredType, + nullptr, + builder.getIntValue(builder.getIntType(), elementSizeAlignment.getStride())); + info.loweredType = innerArrayType; + } return info; } else if (auto structType = as(type)) @@ -625,6 +674,8 @@ struct LoweredElementTypeContext } } auto loweredType = builder.createStructType(); + builder.addPhysicalTypeDecoration(loweredType); + StringBuilder nameSB; getTypeNameHint(nameSB, type); nameSB << "_" << getLayoutName(config.layoutRule->ruleName); @@ -635,12 +686,15 @@ struct LoweredElementTypeContext Index fieldId = 0; for (auto field : structType->getFields()) { - if (as(fieldLoweredTypeInfo[fieldId].loweredType)) + auto& loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId]; + // When lowering type for user pointer, skip fields that are unsized array. + if (config.addressSpace == AddressSpace::UserPointer && + as(loweredFieldTypeInfo.loweredType)) { fieldId++; + loweredFieldTypeInfo.loweredType = builder.getVoidType(); continue; } - auto loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId]; builder.createStructField( loweredType, field->getKey(), @@ -657,10 +711,12 @@ struct LoweredElementTypeContext builder.addNameHintDecoration( info.convertLoweredToOriginal.func, UnownedStringSlice("unpackStorage")); + builder.addForceInlineDecoration(info.convertLoweredToOriginal.func); + auto refLoweredType = builder.getRefType(loweredType, AddressSpace::Generic); info.convertLoweredToOriginal.func->setFullType( - builder.getFuncType(1, (IRType**)&loweredType, type)); + builder.getFuncType(1, (IRType**)&refLoweredType, type)); builder.emitBlock(); - auto loweredParam = builder.emitParam(loweredType); + auto loweredParam = builder.emitParam(refLoweredType); List args; Index fieldId = 0; for (auto field : structType->getFields()) @@ -670,10 +726,7 @@ struct LoweredElementTypeContext fieldId++; continue; } - auto storageField = builder.emitFieldExtract( - fieldLoweredTypeInfo[fieldId].loweredType, - loweredParam, - field->getKey()); + auto storageField = builder.emitFieldAddress(loweredParam, field->getKey()); auto unpackedField = fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal.apply( builder, @@ -694,9 +747,14 @@ struct LoweredElementTypeContext builder.addNameHintDecoration( info.convertOriginalToLowered.func, UnownedStringSlice("packStorage")); + builder.addForceInlineDecoration(info.convertOriginalToLowered.func); + + auto outLoweredType = builder.getRefType(loweredType, AddressSpace::Generic); + IRType* paramTypes[] = {outLoweredType, type}; info.convertOriginalToLowered.func->setFullType( - builder.getFuncType(1, (IRType**)&type, loweredType)); + builder.getFuncType(2, paramTypes, builder.getVoidType())); builder.emitBlock(); + auto outParam = builder.emitParam(outLoweredType); auto param = builder.emitParam(type); List args; Index fieldId = 0; @@ -709,15 +767,15 @@ struct LoweredElementTypeContext } auto fieldVal = builder.emitFieldExtract(field->getFieldType(), param, field->getKey()); - auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered.apply( + auto destAddr = builder.emitFieldAddress(outParam, field->getKey()); + + fieldLoweredTypeInfo[fieldId].convertOriginalToLowered.applyDestinationDriven( builder, - fieldLoweredTypeInfo[fieldId].loweredType, + destAddr, fieldVal); - args.add(packedField); fieldId++; } - auto result = builder.emitMakeStruct(loweredType, args); - builder.emitReturn(result); + builder.emitReturn(); } return info; @@ -743,37 +801,8 @@ struct LoweredElementTypeContext info.loweredType = builder.getVectorType( info.loweredType, vectorType->getElementCount()); - // Create unpack func. - { - builder.setInsertAfter(type); - info.convertLoweredToOriginal = builder.createFunc(); - builder.setInsertInto(info.convertLoweredToOriginal.func); - builder.addNameHintDecoration( - info.convertLoweredToOriginal.func, - UnownedStringSlice("unpackStorage")); - info.convertLoweredToOriginal.func->setFullType( - builder.getFuncType(1, (IRType**)&info.loweredType, type)); - builder.emitBlock(); - auto loweredParam = builder.emitParam(info.loweredType); - auto result = builder.emitCast(type, loweredParam); - builder.emitReturn(result); - } - - // Create pack func. - { - builder.setInsertAfter(info.convertLoweredToOriginal.func); - info.convertOriginalToLowered = builder.createFunc(); - builder.setInsertInto(info.convertOriginalToLowered.func); - builder.addNameHintDecoration( - info.convertOriginalToLowered.func, - UnownedStringSlice("packStorage")); - info.convertOriginalToLowered.func->setFullType( - builder.getFuncType(1, (IRType**)&type, info.loweredType)); - builder.emitBlock(); - auto param = builder.emitParam(type); - auto result = builder.emitCast(info.loweredType, param); - builder.emitReturn(result); - } + info.convertLoweredToOriginal = kIROp_BuiltinCast; + info.convertOriginalToLowered = kIROp_BuiltinCast; return info; } } @@ -828,7 +857,8 @@ struct LoweredElementTypeContext IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType) { if (as(originalPtrLikeType) || as(originalPtrLikeType) || - as(originalPtrLikeType)) + as(originalPtrLikeType) || + as(originalPtrLikeType)) { IRBuilder builder(newElementType); builder.setInsertAfter(newElementType); @@ -859,6 +889,26 @@ struct LoweredElementTypeContext TypeLoweringConfig config; }; + IRInst* getBufferAddr(IRBuilder& builder, IRInst* loadStoreInst) + { + switch (loadStoreInst->getOp()) + { + case kIROp_Load: + case kIROp_Store: + return loadStoreInst->getOperand(0); + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_RWStructuredBufferStore: + return builder.emitRWStructuredBufferGetElementPtr( + loadStoreInst->getOperand(0), + loadStoreInst->getOperand(1)); + default: + return nullptr; + } + } + void processModule(IRModule* module) { IRBuilder builder(module); @@ -891,6 +941,8 @@ struct LoweredElementTypeContext elementType = structBuffer->getElementType(); else if (auto constBuffer = as(globalInst)) elementType = constBuffer->getElementType(); + else if (auto storageBuffer = as(globalInst)) + elementType = storageBuffer->getElementType(); if (as(globalInst)) continue; if (!as(elementType) && !as(elementType) && @@ -908,6 +960,10 @@ struct LoweredElementTypeContext { auto bufferType = bufferTypeInfo.bufferType; auto elementType = bufferTypeInfo.elementType; + + if (elementType->findDecoration()) + continue; + auto config = getTypeLoweringConfigForBuffer(target, bufferType); auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType, config); @@ -954,8 +1010,13 @@ struct LoweredElementTypeContext // getOffsetPtr(trailingPtr, index). if (auto fieldAddr = as(ptrVal)) { - if (auto ptrType = as(ptrVal->getDataType())) + auto handleUnsizedArrayAccess = [&]() -> bool { + auto ptrType = as(ptrVal->getDataType()); + if (!ptrType) + return false; + if (ptrType->getAddressSpace() != AddressSpace::UserPointer) + return false; if (auto unsizedArrayType = as(ptrType->getValueType())) { builder.setInsertBefore(ptrVal); @@ -1019,9 +1080,12 @@ struct LoweredElementTypeContext }); SLANG_ASSERT(!ptrVal->hasUses()); ptrVal->removeAndDeallocate(); - continue; + return true; } - } + return false; + }; + if (handleUnsizedArrayAccess()) + continue; } LoweredElementTypeInfo loweredElementTypeInfo = {}; @@ -1060,7 +1124,7 @@ struct LoweredElementTypeContext getLoweredTypeInfo((IRType*)originalElementType, config); } - if (!loweredElementTypeInfo.convertLoweredToOriginal) + if (loweredElementTypeInfo.loweredType == loweredElementTypeInfo.originalType) continue; ptrVal->setFullType(getLoweredPtrLikeType( @@ -1083,15 +1147,28 @@ struct LoweredElementTypeContext case kIROp_RWStructuredBufferLoadStatus: case kIROp_StructuredBufferConsume: { - IRCloneEnv cloneEnv = {}; builder.setInsertBefore(user); - auto newLoad = cloneInst(&cloneEnv, &builder, user); - newLoad->setFullType(loweredElementTypeInfo.loweredType); + auto addr = getBufferAddr(builder, user); + if (!addr) + { + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setFullType(loweredElementTypeInfo.loweredType); + addr = builder.emitVar(loweredElementTypeInfo.loweredType); + builder.emitStore(addr, newLoad); + } + if (auto alignedAttr = user->findAttr()) + { + builder.addAlignedAddressDecoration( + addr, + alignedAttr->getAlignment()); + } auto unpackedVal = loweredElementTypeInfo.convertLoweredToOriginal.apply( builder, loweredElementTypeInfo.originalType, - newLoad); + addr); user->replaceUsesWith(unpackedVal); user->removeAndDeallocate(); break; @@ -1106,19 +1183,33 @@ struct LoweredElementTypeContext IRCloneEnv cloneEnv = {}; builder.setInsertBefore(user); auto originalVal = getStoreVal(user); - auto packedVal = - loweredElementTypeInfo.convertOriginalToLowered.apply( - builder, - loweredElementTypeInfo.loweredType, - originalVal); - if (auto store = as(user)) - store->val.set(packedVal); - else if (auto sbStore = as(user)) - sbStore->setOperand(2, packedVal); + IRInst* addr = getBufferAddr(builder, user); + if (addr) + { + if (auto alignedAttr = user->findAttr()) + { + builder.addAlignedAddressDecoration( + addr, + alignedAttr->getAlignment()); + } + + loweredElementTypeInfo.convertOriginalToLowered + .applyDestinationDriven(builder, addr, originalVal); + user->removeAndDeallocate(); + } else if (auto sbAppend = as(user)) + { + builder.setInsertBefore(sbAppend); + addr = builder.emitVar(loweredElementTypeInfo.loweredType); + loweredElementTypeInfo.convertOriginalToLowered + .applyDestinationDriven(builder, addr, originalVal); + auto packedVal = builder.emitLoad(addr); sbAppend->setOperand(1, packedVal); + } else + { SLANG_UNREACHABLE("unhandled store type"); + } break; } case kIROp_GetElementPtr: @@ -1176,24 +1267,18 @@ struct LoweredElementTypeContext // access, we need to materialize the object as a local variable, // and pass the address of the local variable to the function. builder.setInsertBefore(user); - auto newLoad = - builder.emitLoad(loweredElementTypeInfo.loweredType, ptrVal); auto unpackedVal = loweredElementTypeInfo.convertLoweredToOriginal.apply( builder, (IRType*)originalElementType, - newLoad); + ptrVal); auto var = builder.emitVar((IRType*)originalElementType); builder.emitStore(var, unpackedVal); use->set(var); builder.setInsertAfter(user); auto newVal = builder.emitLoad(var); - auto packedVal = - loweredElementTypeInfo.convertOriginalToLowered.apply( - builder, - (IRType*)loweredElementTypeInfo.loweredType, - newVal); - builder.emitStore(ptrVal, packedVal); + loweredElementTypeInfo.convertOriginalToLowered + .applyDestinationDriven(builder, ptrVal, newVal); } break; default: @@ -1355,6 +1440,21 @@ void lowerBufferElementTypeToStorageType( context.processModule(module); } +IRTypeLayoutRules* getTypeLayoutRulesFromOp(IROp layoutTypeOp, IRTypeLayoutRules* defaultLayout) +{ + switch (layoutTypeOp) + { + case kIROp_DefaultBufferLayoutType: + return defaultLayout; + case kIROp_Std140BufferLayoutType: + return IRTypeLayoutRules::getStd140(); + case kIROp_Std430BufferLayoutType: + return IRTypeLayoutRules::getStd430(); + case kIROp_ScalarBufferLayoutType: + return IRTypeLayoutRules::getNatural(); + } + return defaultLayout; +} IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) { @@ -1395,18 +1495,7 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf auto layoutTypeOp = structBufferType->getDataLayout() ? structBufferType->getDataLayout()->getOp() : kIROp_DefaultBufferLayoutType; - switch (layoutTypeOp) - { - case kIROp_DefaultBufferLayoutType: - return IRTypeLayoutRules::getStd430(); - case kIROp_Std140BufferLayoutType: - return IRTypeLayoutRules::getStd140(); - case kIROp_Std430BufferLayoutType: - return IRTypeLayoutRules::getStd430(); - case kIROp_ScalarBufferLayoutType: - return IRTypeLayoutRules::getNatural(); - } - return IRTypeLayoutRules::getStd430(); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430()); } case kIROp_ConstantBufferType: case kIROp_ParameterBlockType: @@ -1416,18 +1505,15 @@ IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* buf auto layoutTypeOp = parameterGroupType->getDataLayout() ? parameterGroupType->getDataLayout()->getOp() : kIROp_DefaultBufferLayoutType; - switch (layoutTypeOp) - { - case kIROp_DefaultBufferLayoutType: - return IRTypeLayoutRules::getStd140(); - case kIROp_Std140BufferLayoutType: - return IRTypeLayoutRules::getStd140(); - case kIROp_Std430BufferLayoutType: - return IRTypeLayoutRules::getStd430(); - case kIROp_ScalarBufferLayoutType: - return IRTypeLayoutRules::getNatural(); - } - return IRTypeLayoutRules::getStd140(); + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd140()); + } + case kIROp_GLSLShaderStorageBufferType: + { + auto storageBufferType = as(bufferType); + auto layoutTypeOp = storageBufferType->getDataLayout() + ? storageBufferType->getDataLayout()->getOp() + : kIROp_Std430BufferLayoutType; + return getTypeLayoutRulesFromOp(layoutTypeOp, IRTypeLayoutRules::getStd430()); } case kIROp_PtrType: return IRTypeLayoutRules::getNatural(); @@ -1446,6 +1532,9 @@ TypeLoweringConfig getTypeLoweringConfigForBuffer(TargetProgram* target, IRType* case AddressSpace::Output: addrSpace = AddressSpace::Input; break; + case AddressSpace::UserPointer: + addrSpace = AddressSpace::UserPointer; + break; } } auto rules = getTypeLayoutRuleForBuffer(target, bufferType); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index b19af364e..2e2ba358f 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -9,6 +9,7 @@ #include "slang-ir-dominators.h" #include "slang-ir-float-non-uniform-resource-index.h" #include "slang-ir-glsl-legalize.h" +#include "slang-ir-inline.h" #include "slang-ir-insts.h" #include "slang-ir-layout.h" #include "slang-ir-legalize-global-values.h" @@ -132,6 +133,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase inst->getElementType(), builder.getIntValue(builder.getIntType(), elementSize.getStride())); const auto structType = builder.createStructType(); + builder.addPhysicalTypeDecoration(structType); const auto arrayKey = builder.createStructKey(); builder.createStructField(structType, arrayKey, arrayType); IRSizeAndAlignment structSize; @@ -213,6 +215,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(cbParamInst); builder.setInsertBefore(cbParamInst); auto structType = builder.createStructType(); + builder.addPhysicalTypeDecoration(structType); addToWorkList(structType); StringBuilder sb; sb << "cbuffer_"; @@ -1850,6 +1853,120 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } }; + void propagateAddressAlignment() + { + // Work list of load/store insts to add Aligned attribute to. + List loadStoreInsts; + + for (auto globalInst : m_module->getGlobalInsts()) + { + auto func = as(globalInst); + if (!func) + continue; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + switch (inst->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + auto base = inst->getOperand(0); + auto ptrType = as(base->getDataType()); + if (!ptrType) + break; + // Propagate address alignment if possible. + auto alignDecor = base->findDecoration(); + if (!alignDecor) + break; + auto valueType = ptrType->getValueType(); + auto layout = valueType->findDecoration(); + if (!layout) + break; + auto alignment = getIntVal(alignDecor->getAlignment()); + if (inst->getOp() == kIROp_GetElementPtr) + { + if (alignment >= layout->getAlignment()) + { + IRBuilder builder(inst); + builder.addAlignedAddressDecoration( + inst, + alignDecor->getAlignment()); + } + } + else + { + IRTypeLayoutRuleName layoutRuleName = layout->getLayoutName(); + auto field = findStructField( + valueType, + (IRStructKey*)as(inst)->getField()); + if (!field) + break; + IRIntegerValue offset = 0; + if (getOffset( + m_sharedContext->m_targetProgram->getOptionSet(), + IRTypeLayoutRules::get(layoutRuleName), + field, + &offset) != SLANG_OK) + break; + if (offset % alignment == 0) + { + IRBuilder builder(inst); + builder.addAlignedAddressDecoration( + inst, + alignDecor->getAlignment()); + } + } + } + break; + case kIROp_Load: + case kIROp_Store: + { + if (inst->findAttr()) + break; + loadStoreInsts.add(inst); + } + break; + } + } + } + } + + // Process the work list. + // If load/store doesn't have Aligned attribute, and the ptr has + // a IRAlignedAddress decoration, we should create a load/store + // with a Aligned attribute. + for (auto inst : loadStoreInsts) + { + + if (auto load = as(inst)) + { + auto ptr = load->getPtr(); + if (auto decor = ptr->findDecoration()) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newLoad = + builder.emitLoad(load->getFullType(), ptr, decor->getAlignment()); + load->replaceUsesWith(newLoad); + load->removeAndDeallocate(); + } + } + else if (auto store = as(inst)) + { + auto ptr = store->getPtr(); + if (auto decor = ptr->findDecoration()) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + builder.emitStore(ptr, store->getVal(), decor->getAlignment()); + store->removeAndDeallocate(); + } + } + } + } + void processModule() { determineSpirvVersion(); @@ -1914,8 +2031,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase legalizeSPIRVEntryPoint(func, entryPointDecor); } // SPIRV requires a dominator block to appear before dominated blocks. - // After legalizing the control flow, we need to sort our blocks to ensure this is - // true. + // After legalizing the control flow, we need to sort our blocks to ensure this + // is true. sortBlocksInFunc(func); } } @@ -1939,15 +2056,23 @@ struct SPIRVLegalizationContext : public SourceEmitterBase m_module, bufferElementTypeLoweringOptions); - // The above step may produce empty struct types, so we need to lower them out of existence. + // Inline all pack/unpack storage type functions generated during buffer element + // lowering pass. + performForceInlining(m_module); + + // The above step may produce empty struct types, so we need to lower them out of + // existence. legalizeEmptyTypes(m_sharedContext->m_targetProgram, m_module, m_sink); + // Propagate alignment hints on address instructions. + propagateAddressAlignment(); + // Specalize address space for all pointers. SpirvAddressSpaceAssigner addressSpaceAssigner; specializeAddressSpace(m_module, &addressSpaceAssigner); - // For SPIR-V, we don't skip this validation, because we might then be generating invalid - // SPIR-V. + // For SPIR-V, we don't skip this validation, because we might then be generating + // invalid SPIR-V. bool skipFuncParamValidation = false; validateAtomicOperations(skipFuncParamValidation, m_sink, m_module->getModuleInst()); } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index f75a24ac6..4aef7e7ba 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -616,11 +616,6 @@ void getTypeNameHint(StringBuilder& sb, IRInst* type) case kIROp_HitObjectType: sb << "HitObject"; break; - case kIROp_HLSLConstBufferPointerType: - sb << "ConstantBufferPointer<"; - getTypeNameHint(sb, as(type)->getValueType()); - sb << ">"; - break; case kIROp_HLSLStructuredBufferType: sb << "StructuredBuffer<"; getTypeNameHint(sb, as(type)->getElementType()); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b982385fa..72313217b 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2925,10 +2925,22 @@ IRComPtrType* IRBuilder::getComPtrType(IRType* valueType) return (IRComPtrType*)getType(kIROp_ComPtrType, valueType); } -IRArrayTypeBase* IRBuilder::getArrayTypeBase(IROp op, IRType* elementType, IRInst* elementCount) +IRArrayTypeBase* IRBuilder::getArrayTypeBase( + IROp op, + IRType* elementType, + IRInst* elementCount, + IRInst* stride) { - IRInst* operands[] = {elementType, elementCount}; - return (IRArrayTypeBase*)getType(op, op == kIROp_ArrayType ? 2 : 1, operands); + if (op == kIROp_ArrayType) + { + IRInst* operands[] = {elementType, elementCount, stride}; + return (IRArrayTypeBase*)getType(op, stride ? 3 : 2, operands); + } + else + { + IRInst* operands[] = {elementType, stride}; + return (IRArrayTypeBase*)getType(op, stride ? 2 : 1, operands); + } } IRArrayType* IRBuilder::getArrayType(IRType* elementType, IRInst* elementCount) @@ -4984,6 +4996,14 @@ IRInst* IRBuilder::emitLoad(IRType* type, IRInst* ptr) return inst; } +IRInst* IRBuilder::emitLoad(IRType* type, IRInst* ptr, IRInst* align) +{ + auto inst = createInst(this, kIROp_Load, type, ptr, getAttr(kIROp_AlignedAttr, align)); + + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitLoad(IRInst* ptr) { // Note: a `load` operation does not consider the rate @@ -5023,6 +5043,20 @@ IRInst* IRBuilder::emitStore(IRInst* dstPtr, IRInst* srcVal) return inst; } +IRInst* IRBuilder::emitStore(IRInst* dstPtr, IRInst* srcVal, IRInst* align) +{ + auto inst = createInst( + this, + kIROp_Store, + nullptr, + dstPtr, + srcVal, + getAttr(kIROp_AlignedAttr, align)); + + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitAtomicStore(IRInst* dstPtr, IRInst* srcVal, IRInst* memoryOrder) { auto inst = createInst( diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index ecf5d1c66..aa74c0704 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1816,12 +1816,6 @@ struct IRRTTIPointerType : IRRawPointerTypeBase IR_LEAF_ISA(RTTIPointerType) }; -struct IRHLSLConstBufferPointerType : IRPtrTypeBase -{ - IR_LEAF_ISA(HLSLConstBufferPointerType) - IRInst* getBaseAlignment() { return getOperand(1); } -}; - struct IRGlobalHashedStringLiterals : IRInst { IR_LEAF_ISA(GlobalHashedStringLiterals) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4d692b727..9ca22fead 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -870,8 +870,15 @@ LoweredValInfo emitCallToDeclRef( return LoweredValInfo::simple(args[0]); } auto intrinsicOp = getIntrinsicOp(funcDecl, intrinsicOpModifier); - return LoweredValInfo::simple( - builder->emitIntrinsicInst(type, IROp(intrinsicOp), argCount, args)); + switch (IROp(intrinsicOp)) + { + case kIROp_GetOffsetPtr: + SLANG_ASSERT(argCount == 2); + return LoweredValInfo::simple(builder->emitGetOffsetPtr(args[0], args[1])); + default: + return LoweredValInfo::simple( + builder->emitIntrinsicInst(type, IROp(intrinsicOp), argCount, args)); + } } // Fallback case is to emit an actual call. -- cgit v1.2.3