diff options
| author | Yong He <yonghe@outlook.com> | 2024-07-25 15:00:14 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-25 15:00:14 -0700 |
| commit | c9d89a40775a055873adf82cfb0ee1cb6bdcb93c (patch) | |
| tree | 2438f353e87b30febe966ca23976793637c018d2 /source | |
| parent | 1343ab79fcd0ff9e5ffebbcf95414e51ab19e9cd (diff) | |
Overhaul IR lowering of pointer types. (#4710)
* Overhaul IR lowering of pointer types.
* Propagate address space in IRBuilder.
* Fixup.
* Fix.
* Fix.
* Change how Ptr type is printed to text.
* Fix.
Diffstat (limited to 'source')
40 files changed, 326 insertions, 255 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 45b3435eb..c75d4735b 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -679,14 +679,14 @@ struct __none_t { }; -__generic<T> +__generic<T, let addrSpace : uint64_t = $( (uint64_t)AddressSpace::UserPointer)ULL> __magic_type(PtrType) __intrinsic_type($(kIROp_PtrType)) struct Ptr { __generic<U> __intrinsic_op($(kIROp_BitCast)) - __init(Ptr<U> ptr); + __init(Ptr<U, addrSpace> ptr); __intrinsic_op($(kIROp_CastIntToPtr)) __init(uint64_t val); @@ -703,53 +703,47 @@ struct Ptr }; __intrinsic_op($(kIROp_Load)) -T __load<T>(Ptr<T> ptr); +T __load<T, let addrSpace : uint64_t>(Ptr<T, addrSpace> ptr); __intrinsic_op($(kIROp_Store)) -void __store<T>(Ptr<T> ptr, T val); - -__intrinsic_op($(kIROp_GetElementPtr)) -Ptr<T> __getElementPtr<T>(Ptr<T> ptr, int index); +void __store<T, let addrSpace : uint64_t>(Ptr<T, addrSpace> ptr, T val); __intrinsic_op($(kIROp_GetElementPtr)) -Ptr<T> __getElementPtr<T>(Ptr<T> ptr, int64_t index); +Ptr<T, addrSpace> __getElementPtr<T, let addrSpace : uint64_t, TIndex : __BuiltinIntegerType>(Ptr<T, addrSpace> ptr, TIndex index); __intrinsic_op($(kIROp_GetOffsetPtr)) -Ptr<T> __getOffsetPtr<T>(Ptr<T> ptr, int index); +Ptr<T, addrSpace> __getOffsetPtr<T, let addrSpace : uint64_t, TIndex : __BuiltinIntegerType>(Ptr<T, addrSpace> ptr, TIndex index); -__intrinsic_op($(kIROp_GetOffsetPtr)) -Ptr<T> __getOffsetPtr<T>(Ptr<T> ptr, int64_t index); - -__generic<T> +__generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_Less)) -bool operator<(Ptr<T> p1, Ptr<T> p2); +bool operator <(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); -__generic<T> +__generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_Leq)) -bool operator<=(Ptr<T> p1, Ptr<T> p2); +bool operator <=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); -__generic<T> +__generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_Greater)) -bool operator>(Ptr<T> p1, Ptr<T> p2); +bool operator>(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); -__generic<T> +__generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_Geq)) -bool operator>=(Ptr<T> p1, Ptr<T> p2); +bool operator >=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); -__generic<T> +__generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_Neq)) -bool operator!=(Ptr<T> p1, Ptr<T> p2); +bool operator !=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); -__generic<T> +__generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_Eql)) -bool operator==(Ptr<T> p1, Ptr<T> p2); +bool operator ==(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); extension bool : IRangedValue { - __generic<T> + __generic<T, let addrSpace : uint64_t> __implicit_conversion($(kConversionCost_PtrToBool)) __intrinsic_op($(kIROp_CastPtrToBool)) - __init(Ptr<T> ptr); + __init(Ptr<T, addrSpace> ptr); __generic<T : __EnumType> __implicit_conversion($(kConversionCost_IntegerTruncate)) @@ -765,9 +759,9 @@ extension bool : IRangedValue extension uint64_t : IRangedValue { - __generic<T> + __generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_CastPtrToInt)) - __init(Ptr<T> ptr); + __init(Ptr<T, addrSpace> ptr); static const uint64_t maxValue = 0xFFFFFFFFFFFFFFFFULL; static const uint64_t minValue = 0; @@ -775,9 +769,9 @@ extension uint64_t : IRangedValue extension int64_t : IRangedValue { - __generic<T> + __generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_CastPtrToInt)) - __init(Ptr<T> ptr); + __init(Ptr<T, addrSpace> ptr); static const int64_t maxValue = 0x7FFFFFFFFFFFFFFFLL; static const int64_t minValue = -0x8000000000000000LL; @@ -785,9 +779,9 @@ extension int64_t : IRangedValue extension intptr_t : IRangedValue { - __generic<T> + __generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_CastPtrToInt)) - __init(Ptr<T> ptr); + __init(Ptr<T, addrSpace> ptr); static const intptr_t maxValue = $(SLANG_PROCESSOR_X86_64?"0x7FFFFFFFFFFFFFFFz":"0x7FFFFFFFz"); static const intptr_t minValue = $(SLANG_PROCESSOR_X86_64?"0x8000000000000000z":"0x80000000z"); static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4"); @@ -795,9 +789,9 @@ extension intptr_t : IRangedValue extension uintptr_t : IRangedValue { - __generic<T> + __generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_CastPtrToInt)) - __init(Ptr<T> ptr); + __init(Ptr<T, addrSpace> ptr); static const uintptr_t maxValue = $(SLANG_PROCESSOR_X86_64?"0xFFFFFFFFFFFFFFFFz":"0xFFFFFFFFz"); static const uintptr_t minValue = 0z; static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4"); @@ -827,6 +821,8 @@ __intrinsic_type($(kIROp_ConstRefType)) struct ConstRef {}; +typealias __Addr<T> = Ptr<T, $( (uint64_t)AddressSpace::Generic)ULL>; + __generic<T> __magic_type(OptionalType) __intrinsic_type($(kIROp_OptionalType)) @@ -969,10 +965,10 @@ extension Ptr<void> [__unsafeForceInlineEarly] __init(NativeString nativeStr) { this = nativeStr.getBuffer(); } - __generic<T> + __generic<T, let addrSpace : uint64_t> __intrinsic_op(0) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) - __init(Ptr<T> ptr); + __init(Ptr<T, addrSpace> ptr); __generic<T> __intrinsic_op(0) @@ -1628,21 +1624,28 @@ for (auto op : intrinsicUnaryOps) }}}} -__generic<T> +__generic<T, let addrSpace : uint64_t> __intrinsic_op(0) -__prefix Ref<T> operator*(Ptr<T> value); +[require(cpp_cuda_spirv)] +__prefix Ref<T> operator*(Ptr<T, addrSpace> value); __generic<T> __intrinsic_op(0) -__prefix Ptr<T> operator&(__ref T value); +[require(cpp_cuda_spirv)] +__prefix Ptr<T, $( (uint64_t)AddressSpace::UserPointer)ULL> operator&(__ref T value); __generic<T> +__intrinsic_op(0) +[require(cpp_cuda_spirv)] +__Addr<T> __get_addr( __ref T value); + +__generic<T, let addrSpace : uint64_t> __intrinsic_op($(kIROp_GetOffsetPtr)) -Ptr<T> operator+(Ptr<T> value, int64_t offset); +Ptr<T, addrSpace> operator+(Ptr<T, addrSpace> value, int64_t offset); -__generic<T> +__generic<T, let addrSpace : uint64_t> [__unsafeForceInlineEarly] -Ptr<T> operator-(Ptr<T> value, int64_t offset) +Ptr<T, addrSpace> operator -(Ptr<T, addrSpace> value, int64_t offset) { return __getOffsetPtr(value, -offset); } @@ -1707,9 +1710,9 @@ matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value) {$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); } $(fixity.qual) -__generic<T> +__generic<T, let addrSpace : uint64_t> [__unsafeForceInlineEarly] -Ptr<T> operator$(op.name)(in out Ptr<T> value) +Ptr<T, addrSpace> operator$(op.name)(in out Ptr<T, addrSpace> value) {$(fixity.bodyPrefix) value = value $(op.binOp) 1; return $(fixity.returnVal); } ${{{{ @@ -2500,8 +2503,8 @@ __intrinsic_op($(kIROp_TreatAsDynamicUniform)) T asDynamicUniform<T>(T v); __generic<T> -__intrinsic_op($(kIROp_GetLegalizedSPIRVGlobalParamAddr)) -Ptr<T> __getLegalizedSPIRVGlobalParamAddr(T val); +__intrinsic_op( $(kIROp_GetLegalizedSPIRVGlobalParamAddr)) +__Addr<T> __getLegalizedSPIRVGlobalParamAddr(T val); __intrinsic_op($(kIROp_RequireComputeDerivative)) void __requireComputeDerivative(); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 1a112e1a9..0539097eb 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -17568,11 +17568,11 @@ Ref<T> __hitObjectAttributes<T>() return t; } [ForceInline] -Ptr<T> __allocHitObjectAttributes<T>() +__Addr<T> __allocHitObjectAttributes<T>() { - [__vulkanHitObjectAttributes] + [__vulkanHitObjectAttributes] static T t; - return &t; + return __get_addr(t); } // Next is the custom intrinsic that will compute the hitObjectAttributes location @@ -17840,7 +17840,7 @@ struct HitObject case spirv: { // Save the attributes - Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); + __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); *attr = attributes; @@ -17914,7 +17914,7 @@ struct HitObject case spirv: { // Save the attributes - Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); + __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); *attr = attributes; @@ -18004,7 +18004,7 @@ struct HitObject case spirv: { // Save the attributes - Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); + __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); *attr = attributes; let origin = Ray.Origin; let direction = Ray.Direction; @@ -18071,7 +18071,7 @@ struct HitObject case spirv: { // Save the attributes - Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); + __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); *attr = attributes; let origin = Ray.Origin; let direction = Ray.Direction; @@ -18618,7 +18618,7 @@ struct HitObject } case spirv: { - Ptr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); + __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); spirv_asm { OpExtension "SPV_NV_shader_invocation_reorder"; diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index a672c1b7e..ce4c32c3a 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -290,9 +290,9 @@ Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const c return rsType; } -PtrType* ASTBuilder::getPtrType(Type* valueType) +PtrType* ASTBuilder::getPtrType(Type* valueType, AddressSpace addrSpace) { - return dynamicCast<PtrType>(getPtrType(valueType, "PtrType")); + return dynamicCast<PtrType>(getPtrType(valueType, addrSpace, "PtrType")); } // Construct the type `Out<valueType>` @@ -306,9 +306,9 @@ InOutType* ASTBuilder::getInOutType(Type* valueType) return dynamicCast<InOutType>(getPtrType(valueType, "InOutType")); } -RefType* ASTBuilder::getRefType(Type* valueType) +RefType* ASTBuilder::getRefType(Type* valueType, AddressSpace addrSpace) { - return dynamicCast<RefType>(getPtrType(valueType, "RefType")); + return dynamicCast<RefType>(getPtrType(valueType, addrSpace, "RefType")); } ConstRefType* ASTBuilder::getConstRefType(Type* valueType) @@ -327,6 +327,12 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) return as<PtrTypeBase>(getSpecializedBuiltinType(valueType, ptrTypeName)); } +PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName) +{ + Val* args[] = { valueType, getIntVal(getUInt64Type(), (IntegerLiteralValue)addrSpace) }; + return as<PtrTypeBase>(getSpecializedBuiltinType(makeArrayView(args), ptrTypeName)); +} + ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount) { if (!elementCount) diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 5b4ec5538..52858a6b1 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -6,7 +6,7 @@ #include "slang-ast-support-types.h" #include "slang-ast-all.h" - +#include "slang-ir.h" #include "../core/slang-type-traits.h" #include "../core/slang-memory-arena.h" @@ -439,7 +439,7 @@ public: Type* getDiffInterfaceType() { return m_sharedASTBuilder->getDiffInterfaceType(); } // Construct the type `Ptr<valueType>`, where `Ptr` // is looked up as a builtin type. - PtrType* getPtrType(Type* valueType); + PtrType* getPtrType(Type* valueType, AddressSpace addrSpace); // Construct the type `Out<valueType>` OutType* getOutType(Type* valueType); @@ -448,7 +448,7 @@ public: InOutType* getInOutType(Type* valueType); // Construct the type `Ref<valueType>` - RefType* getRefType(Type* valueType); + RefType* getRefType(Type* valueType, AddressSpace addrSpace); // Construct the type `ConstRef<valueType>` ConstRefType* getConstRefType(Type* valueType); @@ -459,6 +459,7 @@ public: // Construct a pointer type like `Ptr<valueType>`, but where // the actual type name for the pointer type is given by `ptrTypeName` PtrTypeBase* getPtrType(Type* valueType, char const* ptrTypeName); + PtrTypeBase* getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName); ArrayExpressionType* getArrayType(Type* elementType, IntVal* elementCount); diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 47cd68b9e..44585ee30 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -352,6 +352,66 @@ Type* NativeRefType::getValueType() return as<Type>(_getGenericTypeArg(this, 0)); } +Val* PtrTypeBase::getAddressSpace() +{ + return _getGenericTypeArg(this, 1); +} + +AddressSpace tryGetAddressSpaceValue(Val* addrSpaceVal) +{ + AddressSpace addrSpace = AddressSpace::Generic; + + if (auto cintVal = as<ConstantIntVal>(addrSpaceVal)) + { + addrSpace = (AddressSpace)(cintVal->getValue()); + } + return addrSpace; +} + +void maybePrintAddrSpaceOperand(StringBuilder& out, AddressSpace addrSpace) +{ + switch (addrSpace) + { + case AddressSpace::Generic: + case AddressSpace::UserPointer: + break; + case AddressSpace::GroupShared: + out << toSlice(", groupshared"); + break; + case AddressSpace::Global: + out << toSlice(", global"); + break; + case AddressSpace::ThreadLocal: + out << toSlice(", threadlocal"); + break; + case AddressSpace::Uniform: + out << toSlice(", uniform"); + break; + default: + break; + } +} + +void PtrType::_toTextOverride(StringBuilder& out) +{ + auto addrSpace = tryGetAddressSpaceValue(getAddressSpace()); + if (addrSpace == AddressSpace::Generic) + out << toSlice("Addr<") << getValueType(); + else + out << toSlice("Ptr<") << getValueType(); + maybePrintAddrSpaceOperand(out, addrSpace); + out << toSlice(">"); +} + +void RefType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("Ref<") << getValueType(); + auto addressSpaceVal = getAddressSpace(); + maybePrintAddrSpaceOperand(out, tryGetAddressSpaceValue(addressSpaceVal)); + out << toSlice(">"); +} + + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void NamedExpressionType::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 0d1a89860..945d051ba 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -539,6 +539,8 @@ class PtrTypeBase : public BuiltinType // Get the type of the pointed-to value. Type* getValueType(); + + Val* getAddressSpace(); }; class NoneType : public BuiltinType @@ -555,6 +557,8 @@ class NullPtrType : public BuiltinType class PtrType : public PtrTypeBase { SLANG_AST_CLASS(PtrType) + + void _toTextOverride(StringBuilder& out); }; // A GPU pointer type into global memory. @@ -599,6 +603,7 @@ class RefTypeBase : public ParamDirectionType class RefType : public RefTypeBase { SLANG_AST_CLASS(RefType) + void _toTextOverride(StringBuilder& out); }; // The type for an `constref` parameter, e.g., `constref T` diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index b13571d18..59e79a07a 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -137,6 +137,7 @@ alias any_gfx_target = hlsl | metal | glsl | spirv; alias any_cpp_target = cpp | cuda; alias cpp_cuda = cpp | cuda; +alias cpp_cuda_spirv = cpp | cuda | spirv; alias cpp_cuda_glsl_spirv = cpp | cuda | glsl | spirv; alias cpp_cuda_glsl_hlsl = cpp | cuda | glsl | hlsl; alias cpp_cuda_glsl_hlsl_spirv = cpp | cuda | glsl | hlsl | spirv; diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 111f4e465..00625d5f4 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1024,7 +1024,6 @@ namespace Slang return false; if (as<RefType>(toType) && !fromExpr->type.isLeftValue) return false; - ConversionCost subCost = kConversionCost_GetRef; MakeRefExpr* refExpr = nullptr; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index fd03e5e87..96e0a95d0 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -4425,7 +4425,7 @@ namespace Slang expr->base = CheckProperType(expr->base); if (as<ErrorType>(expr->base.type)) expr->type = expr->base.type; - auto ptrType = m_astBuilder->getPtrType(expr->base.type); + auto ptrType = m_astBuilder->getPtrType(expr->base.type, AddressSpace::UserPointer); expr->type = m_astBuilder->getTypeType(ptrType); return expr; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 5c508a541..bf478a17d 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -885,7 +885,7 @@ DIAGNOSTIC(99999, Internal, internalCompilerError, "Slang internal compiler erro DIAGNOSTIC(99999, Error, compilationAborted, "Slang compilation aborted due to internal error") DIAGNOSTIC(99999, Error, compilationAbortedDueToException, "Slang compilation aborted due to an exception of $0: $1") DIAGNOSTIC(99999, Internal, serialDebugVerificationFailed, "Verification of serial debug information failed.") -DIAGNOSTIC(99999, Internal, spirvValidationFailed, "Validation of generated SPIR-V failed.") +DIAGNOSTIC(99999, Internal, spirvValidationFailed, "Validation of generated SPIR-V failed. SPIRV generated: \n$0") DIAGNOSTIC(99999, Internal, noBlocksOrIntrinsic, "no blocks found for function definition, is there a '$0' intrinsic missing?") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 7ce2c7900..dd5ef8c51 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1814,7 +1814,7 @@ void CLikeSourceEmitter::emitRateQualifiers(IRInst* value) const auto rate = value->getRate(); if (rate) { - emitRateQualifiersAndAddressSpaceImpl(rate, -1); + emitRateQualifiersAndAddressSpaceImpl(rate, AddressSpace::Generic); } } @@ -1822,8 +1822,8 @@ void CLikeSourceEmitter::emitRateQualifiersAndAddressSpace(IRInst* value) { const auto rate = value->getRate(); const auto ptrTy = composeGetters<IRPtrTypeBase>(value, &IRInst::getDataType); - const auto addressSpace = ptrTy ? ptrTy->getAddressSpace() : -1; - if (rate || addressSpace != -1) + const auto addressSpace = ptrTy ? ptrTy->getAddressSpace() : AddressSpace::Generic; + if (rate || addressSpace != AddressSpace::Generic) { emitRateQualifiersAndAddressSpaceImpl(rate, addressSpace); } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 07cb4a0bc..c866456a2 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -488,7 +488,7 @@ public: virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); }; - virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); } + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); } virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) { SLANG_UNUSED(inst); SLANG_UNUSED(allowOffsetLayout); } virtual void emitSimpleFuncParamImpl(IRParam* param); virtual void emitSimpleFuncParamsImpl(IRFunc* func); diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index f4b45b7aa..ab7ee8f58 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -738,7 +738,7 @@ void CUDASourceEmitter::emitSimpleTypeImpl(IRType* type) } } -void CUDASourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace) +void CUDASourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] AddressSpace addressSpace) { if (as<IRGroupSharedRate>(rate)) { diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h index 5d76c4ef3..ac7a33302 100644 --- a/source/slang/slang-emit-cuda.h +++ b/source/slang/slang-emit-cuda.h @@ -67,7 +67,7 @@ protected: virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE; virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE; - virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE; + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) SLANG_OVERRIDE; virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) SLANG_OVERRIDE; virtual void emitSimpleFuncImpl(IRFunc* func) SLANG_OVERRIDE; virtual void emitSimpleFuncParamsImpl(IRFunc* func) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index eea672938..0fb868342 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -2638,9 +2638,9 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled type"); } -void GLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) +void GLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { - if(addressSpace == SpvStorageClassTaskPayloadWorkgroupEXT) + if(addressSpace == (AddressSpace)SpvStorageClassTaskPayloadWorkgroupEXT) { m_writer->emit("taskPayloadSharedEXT "); } diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index efd3ded75..1911749ec 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -33,7 +33,7 @@ protected: virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; - virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE; + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) SLANG_OVERRIDE; virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE; virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 4d3830969..d3757f534 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -1096,7 +1096,7 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) } } -void HLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace) +void HLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] AddressSpace addressSpace) { if (as<IRGroupSharedRate>(rate)) { diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index abc673a3d..2137a6d3d 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -35,7 +35,7 @@ protected: virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; - virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE; + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) SLANG_OVERRIDE; virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE; virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE; virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index e7df29e0c..302e2b6b7 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -984,7 +984,7 @@ void MetalSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTy // We emit packoffset as a semantic in `emitSemantic`, so nothing to do here. } -void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) +void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { if (as<IRGroupSharedRate>(rate)) { @@ -992,7 +992,7 @@ void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRI return; } - switch ((AddressSpace)addressSpace) + switch (addressSpace) { case AddressSpace::GroupShared: m_writer->emit("threadgroup "); diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index 32557bf27..6460b076a 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -31,7 +31,7 @@ protected: virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; - virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE; + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) SLANG_OVERRIDE; virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE; virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE; virtual void emitPostDeclarationAttributesForType(IRInst* type) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index f5bb21c00..cb9eb0ae9 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1108,7 +1108,17 @@ struct SPIRVEmitContext // If we have seen this before, return the memoized instruction if (SpvInst** memoized = m_spvTypeInsts.tryGetValue(key)) + { + // There could be another different slang IR inst that translates to + // the same spir-v inst. + // For example, both Ptr<T> and Ref<T> translates to the same pointer + // type in spirv. + // In this case we need to make sure we also + // register `inst` to map it to the memoized spir-v inst. + if (irInst) + m_mapIRInstToSpvInst.addIfNotExists(irInst, *memoized); return *memoized; + } // Otherwise, we can construct our instruction and record the result InstConstructScope scopeInst(this, opcode, irInst); @@ -1213,6 +1223,19 @@ struct SPIRVEmitContext return m_NonSemanticDebugPrintfExtInst; } + SpvStorageClass addressSpaceToStorageClass(AddressSpace addrSpace) + { + switch (addrSpace) + { + case AddressSpace::Generic: + return SpvStorageClassMax; + case AddressSpace::UserPointer: + return SpvStorageClassPhysicalStorageBuffer; + default: + return (SpvStorageClass)addrSpace; + } + } + // Now that we've gotten the core infrastructure out of the way, // let's start looking at emitting some instructions that make // up a SPIR-V module. @@ -1398,7 +1421,7 @@ struct SPIRVEmitContext auto ptrType = as<IRPtrTypeBase>(inst); SLANG_ASSERT(ptrType); if (ptrType->hasAddressSpace()) - storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); if (storageClass == SpvStorageClassStorageBuffer) ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_storage_buffer_storage_class")); if (storageClass == SpvStorageClassPhysicalStorageBuffer) @@ -1411,14 +1434,24 @@ struct SPIRVEmitContext && as<IRStructType>(valueType) && storageClass == SpvStorageClassPhysicalStorageBuffer); SpvId valueTypeId; - if (useForwardDeclaration) + if (as<IRVoidType>(valueType)) { - valueTypeId = getIRInstSpvID(valueType); + // Emit void* as uint*. + IRBuilder builder(valueType); + builder.setInsertBefore(valueType); + valueTypeId = getID(ensureInst(builder.getUIntType())); } else { - auto spvValueType = ensureInst(valueType); - valueTypeId = getID(spvValueType); + if (useForwardDeclaration) + { + valueTypeId = getIRInstSpvID(valueType); + } + else + { + auto spvValueType = ensureInst(valueType); + valueTypeId = getID(spvValueType); + } } auto resultSpvType = emitOpTypePointer( @@ -3339,7 +3372,7 @@ struct SPIRVEmitContext if (auto ptrType = as<IRPtrTypeBase>(globalInst->getDataType())) { auto addrSpace = ptrType->getAddressSpace(); - if (addrSpace != SpvStorageClassInput && addrSpace != SpvStorageClassOutput) + if (addrSpace != AddressSpace(SpvStorageClassInput) && addrSpace != AddressSpace(SpvStorageClassOutput)) continue; } } @@ -4042,7 +4075,7 @@ struct SPIRVEmitContext if (!ptrType) return; auto addrSpace = ptrType->getAddressSpace(); - if (addrSpace == SpvStorageClassInput) + if (addrSpace == AddressSpace(SpvStorageClassInput)) { if (isIntegralScalarOrCompositeType(ptrType->getValueType())) { @@ -4333,10 +4366,10 @@ struct SPIRVEmitContext void maybeEmitPointerDecoration(SpvInst* varInst, IRInst* inst) { - auto ptrType = as<IRPtrType>(inst->getDataType()); + auto ptrType = as<IRPtrType>(unwrapArray(inst->getDataType())); if (!ptrType) return; - if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + if (addressSpaceToStorageClass(ptrType->getAddressSpace()) == SpvStorageClassPhysicalStorageBuffer) { // If inst has a pointer type with PhysicalStorageBuffer address space, // emit AliasedPointer decoration. @@ -4351,10 +4384,10 @@ struct SPIRVEmitContext { // If the pointee type is a pointer with StorageBuffer address space, // we also want to emit AliasedPointer decoration. - ptrType = as<IRPtrType>(ptrType->getValueType()); + ptrType = as<IRPtrType>(unwrapArray(ptrType->getValueType())); if (!ptrType) return; - if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + if (addressSpaceToStorageClass(ptrType->getAddressSpace()) == SpvStorageClassPhysicalStorageBuffer) { emitOpDecorate( getSection(SpvLogicalSectionID::Annotations), @@ -4996,7 +5029,7 @@ struct SPIRVEmitContext SpvInst* emitLoad(SpvInstParent* parent, IRLoad* inst) { auto ptrType = as<IRPtrTypeBase>(inst->getPtr()->getDataType()); - if (ptrType && ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + if (ptrType && addressSpaceToStorageClass(ptrType->getAddressSpace()) == SpvStorageClassPhysicalStorageBuffer) { IRSizeAndAlignment sizeAndAlignment; getNaturalSizeAndAlignment(m_targetProgram->getOptionSet(), ptrType->getValueType(), &sizeAndAlignment); @@ -5011,7 +5044,7 @@ struct SPIRVEmitContext SpvInst* emitStore(SpvInstParent* parent, IRStore* inst) { auto ptrType = as<IRPtrTypeBase>(inst->getPtr()->getDataType()); - if (ptrType && ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + if (ptrType && addressSpaceToStorageClass(ptrType->getAddressSpace()) == SpvStorageClassPhysicalStorageBuffer) { IRSizeAndAlignment sizeAndAlignment; getNaturalSizeAndAlignment(m_targetProgram->getOptionSet(), ptrType->getValueType(), &sizeAndAlignment); @@ -5077,7 +5110,7 @@ struct SPIRVEmitContext return emitOpVectorShuffle(parent, inst, inst->getFullType(), inst->getBase(), inst->getSource(), shuffleIndices.getArrayView()); } - IRPtrTypeBase* getPtrTypeWithAddressSpace(IRPtrTypeBase* ptrTypeWithNoAddressSpace, IRIntegerValue addressSpace) + IRPtrTypeBase* getPtrTypeWithAddressSpace(IRPtrTypeBase* ptrTypeWithNoAddressSpace, AddressSpace addressSpace) { // If it's already ok, return as is if(ptrTypeWithNoAddressSpace->getAddressSpace() == addressSpace) @@ -5104,7 +5137,7 @@ struct SPIRVEmitContext parent, inst, // Make sure the resulting pointer has the correct storage class - getPtrTypeWithAddressSpace(cast<IRPtrTypeBase>(inst->getDataType()), storageClass), + getPtrTypeWithAddressSpace(cast<IRPtrTypeBase>(inst->getDataType()), AddressSpace(storageClass)), inst->getOperand(0), makeArray(emitIntConstant(0, builder.getIntType()), ensureInst(inst->getOperand(1))) ); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 9e21ccbfd..03d1b932c 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1718,9 +1718,13 @@ SlangResult emitSPIRVForEntryPointsDirectly( { if (SLANG_FAILED(compiler->validate((uint32_t*)spirv.getBuffer(), int(spirv.getCount()/4)))) { + String err; + String dis; + disassembleSPIRV(spirv, err, dis); codeGenContext->getSink()->diagnoseWithoutSourceView( SourceLoc{}, - Diagnostics::spirvValidationFailed + Diagnostics::spirvValidationFailed, + dis ); } } diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 13059a84a..f09294aa9 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -145,13 +145,10 @@ namespace Slang case kIROp_VectorType: { auto vectorType = static_cast<IRVectorType*>(dataType); - auto elementType = vectorType->getElementType(); auto elementCount = getIntVal(vectorType->getElementCount()); - auto elementPtrType = builder->getPtrType(elementType); for (IRIntegerValue i = 0; i < elementCount; i++) { auto elementAddr = builder->emitElementAddress( - elementPtrType, concreteTypedVar, builder->getIntValue(builder->getIntType(), i)); emitMarshallingCode(builder, context, elementAddr); @@ -161,20 +158,16 @@ namespace Slang case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(dataType); - auto elementType = matrixType->getElementType(); auto colCount = getIntVal(matrixType->getColumnCount()); auto rowCount = getIntVal(matrixType->getRowCount()); - auto rowVecType = builder->getVectorType(elementType, matrixType->getRowCount()); for (IRIntegerValue i = 0; i < colCount; i++) { auto col = builder->emitElementAddress( - builder->getPtrType(rowVecType), concreteTypedVar, builder->getIntValue(builder->getIntType(), i)); for (IRIntegerValue j = 0; j < rowCount; j++) { auto element = builder->emitElementAddress( - builder->getPtrType(elementType), col, builder->getIntValue(builder->getIntType(), j)); emitMarshallingCode(builder, context, element); @@ -198,11 +191,9 @@ namespace Slang case kIROp_ArrayType: { auto arrayType = cast<IRArrayType>(dataType); - auto elementPtrType = builder->getPtrType(arrayType->getElementType()); for (IRIntegerValue i = 0; i < getIntVal(arrayType->getElementCount()); i++) { auto fieldAddr = builder->emitElementAddress( - elementPtrType, concreteTypedVar, builder->getIntValue(builder->getIntType(), i)); emitMarshallingCode(builder, context, fieldAddr); diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 511055713..9fe4ec70b 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1117,14 +1117,9 @@ IRInst* emitIndexedStoreAddressForVar( const List<IndexTrackingInfo>& defBlockIndices) { IRInst* storeAddr = localVar; - IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType(); - for (auto& index : defBlockIndices) { - currType = as<IRArrayType>(currType)->getElementType(); - storeAddr = builder->emitElementAddress( - builder->getPtrType(currType), storeAddr, index.primalCountParam); } @@ -1141,11 +1136,9 @@ IRInst* emitIndexedLoadAddressForVar( const List<IndexTrackingInfo>& useBlockIndices) { IRInst* loadAddr = localVar; - IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType(); for (auto index : defBlockIndices) { - currType = as<IRArrayType>(currType)->getElementType(); if (useBlockIndices.contains(index)) { // If the use-block is under the same region, use the @@ -1154,7 +1147,6 @@ IRInst* emitIndexedLoadAddressForVar( auto diffCounterCurrValue = index.diffCountParam; loadAddr = builder->emitElementAddress( - builder->getPtrType(currType), loadAddr, diffCounterCurrValue); } @@ -1173,7 +1165,6 @@ IRInst* emitIndexedLoadAddressForVar( builder->getIntValue(builder->getIntType(), 1)); loadAddr = builder->emitElementAddress( - builder->getPtrType(currType), loadAddr, primalCounterLastValue); } diff --git a/source/slang/slang-ir-composite-reg-to-mem.cpp b/source/slang/slang-ir-composite-reg-to-mem.cpp index 243a0e2b0..512683938 100644 --- a/source/slang/slang-ir-composite-reg-to-mem.cpp +++ b/source/slang/slang-ir-composite-reg-to-mem.cpp @@ -36,7 +36,6 @@ namespace Slang if (getElementUser->getOperands() == use) { newAddr = builder.emitElementAddress( - builder.getPtrType(user->getFullType()), addr, getElementUser->getIndex()); } diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp index 9d3f09712..56fd62883 100644 --- a/source/slang/slang-ir-explicit-global-context.cpp +++ b/source/slang/slang-ir-explicit-global-context.cpp @@ -258,7 +258,7 @@ struct IntroduceExplicitGlobalContextPass if (kind == GlobalObjectKind::GlobalVar) { auto ptrType = as<IRPtrTypeBase>(type); - if (ptrType->getAddressSpace() == (IRIntegerValue)AddressSpace::GroupShared) + if (ptrType->getAddressSpace() == AddressSpace::GroupShared) { fieldDataType = ptrType; needDereference = true; diff --git a/source/slang/slang-ir-glsl-liveness.cpp b/source/slang/slang-ir-glsl-liveness.cpp index af64df4f4..64d41490a 100644 --- a/source/slang/slang-ir-glsl-liveness.cpp +++ b/source/slang/slang-ir-glsl-liveness.cpp @@ -133,7 +133,7 @@ void GLSLLivenessContext::_replaceMarker(IRLiveRangeMarker* markerInst) IRType* paramTypes[] = { - m_builder.getRefType(referencedType), ///< Use a reference to the referenced type + m_builder.getRefType(referencedType, AddressSpace::Generic), ///< Use a reference to the referenced type m_spirvIntLiteralType, ///< The size type }; diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a5493d6ea..3aa7d1f64 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3490,10 +3490,11 @@ public: IRPtrType* getPtrType(IRType* valueType); IROutType* getOutType(IRType* valueType); IRInOutType* getInOutType(IRType* valueType); - IRRefType* getRefType(IRType* valueType); + IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace); IRConstRefType* getConstRefType(IRType* valueType); IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace); + IRPtrType* getPtrType(IROp op, IRType* valueType, IRInst* addressSpace); IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) { return getPtrType(op, valueType, (IRIntegerValue)addressSpace); } IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { return getPtrType(kIROp_PtrType, valueType, (IRIntegerValue)addressSpace); } diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 5eefc121e..cc24a0e81 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -578,7 +578,7 @@ namespace Slang { if (auto ptrType = as<IRPtrType>(globalInst)) { - if (ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBuffer) + if (ptrType->getAddressSpace() == AddressSpace::UserPointer) elementType = ptrType->getValueType(); } } diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp index 7fd609011..953d9f68a 100644 --- a/source/slang/slang-ir-simplify-for-emit.cpp +++ b/source/slang/slang-ir-simplify-for-emit.cpp @@ -73,7 +73,6 @@ struct SimplifyForEmitContext : public InstPassBase for (UInt i = 0; i < makeArray->getOperandCount(); i++) { auto elementAddr = builder.emitElementAddress( - builder.getPtrType(arrayType->getElementType()), store->getPtr(), builder.getIntValue(builder.getIntType(), (IRIntegerValue)i)); builder.emitStore(elementAddr, makeArray->getOperand(i)); @@ -107,7 +106,6 @@ struct SimplifyForEmitContext : public InstPassBase for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { auto elementAddr = builder.emitElementAddress( - builder.getPtrType(arrayType->getElementType()), store->getPtr(), builder.getIntValue(builder.getIntType(), i)); builder.emitStore(elementAddr, makeArray->getOperand(0)); diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp index 1d899e240..24eee3d76 100644 --- a/source/slang/slang-ir-specialize-address-space.cpp +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -331,9 +331,12 @@ namespace Slang auto ptrType = as<IRPtrTypeBase>(inst->getDataType()); if (ptrType) { - IRBuilder builder(inst); - auto newType = builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), addrSpace); - setDataType(inst, newType); + if (ptrType->getAddressSpace() != addrSpace) + { + IRBuilder builder(inst); + auto newType = builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), addrSpace); + setDataType(inst, newType); + } } } } diff --git a/source/slang/slang-ir-specialize-address-space.h b/source/slang/slang-ir-specialize-address-space.h index 300b6129c..61b191cf2 100644 --- a/source/slang/slang-ir-specialize-address-space.h +++ b/source/slang/slang-ir-specialize-address-space.h @@ -1,11 +1,13 @@ // slang-ir-specialize-address-space.h #pragma once +#include <cinttypes> + namespace Slang { struct IRModule; struct IRInst; - enum class AddressSpace; + enum class AddressSpace : uint64_t; struct AddressSpaceSpecializationContext { diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 5713b9639..5c9e1ad24 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1996,9 +1996,6 @@ struct SpecializationContext auto index = inst->getIndex(); auto val = wrapInst->getWrappedValue(); - auto ptrType = cast<IRPtrTypeBase>(val->getDataType()); - auto arrayType = cast<IRArrayTypeBase>(ptrType->getValueType()); - auto elementType = arrayType->getElementType(); auto resultType = inst->getFullType(); @@ -2013,8 +2010,7 @@ struct SpecializationContext slotOperands.add(wrapInst->getSlotOperand(ii)); } - auto elementPtrType = builder.getPtrType(ptrType->getOp(), elementType); - auto newElementAddr = builder.emitElementAddress(elementPtrType, val, index); + auto newElementAddr = builder.emitElementAddress(val, index); auto newWrapExistentialInst = builder.emitWrapExistential( resultType, newElementAddr, slotOperandCount, slotOperands.getArrayView().getBuffer()); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 27a186ee9..528fe2331 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -783,7 +783,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(inst); builder.setInsertBefore(inst); auto newPtrType = builder.getPtrType( - oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), SpvStorageClassFunction); + oldPtrType->getOp(), oldPtrType->getValueType(), SpvStorageClassFunction); inst->setFullType(newPtrType); addUsersToWorkList(inst); } @@ -800,12 +800,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase return; if (!oldPtrType->hasAddressSpace()) { - SpvStorageClass addressSpace = (SpvStorageClass)-1; + AddressSpace addressSpace = AddressSpace::Generic; if (block == func->getFirstBlock()) { // A pointer typed function parameter should always be in the storage buffer address space. - addressSpace = SpvStorageClassPhysicalStorageBuffer; + addressSpace = AddressSpace::UserPointer; } else { @@ -816,19 +816,19 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto argPtrType = as<IRPtrType>(arg->getDataType()); if (argPtrType->hasAddressSpace()) { - if (addressSpace == (SpvStorageClass)-1) - addressSpace = (SpvStorageClass)argPtrType->getAddressSpace(); + if (addressSpace == AddressSpace::Generic) + addressSpace = argPtrType->getAddressSpace(); else if (addressSpace != argPtrType->getAddressSpace()) m_sharedContext->m_sink->diagnose(inst, Diagnostics::inconsistentPointerAddressSpace, inst); } } } - if (addressSpace != (SpvStorageClass)-1) + if (addressSpace != AddressSpace::Generic) { IRBuilder builder(inst); builder.setInsertBefore(inst); auto newPtrType = builder.getPtrType( - oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), SpvStorageClassPhysicalStorageBuffer); + oldPtrType->getOp(), oldPtrType->getValueType(), AddressSpace::UserPointer); inst->setFullType(newPtrType); addUsersToWorkList(inst); } @@ -842,7 +842,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase return; // Update the pointer value type with storage-buffer-address-space-decorated types. - auto newPtrValueType = translateToStorageBufferPointer(oldPtrType->getValueType()); + auto newPtrValueType = oldPtrType->getValueType(); if (newPtrValueType != oldPtrType->getValueType()) { IRBuilder builder(inst); @@ -900,7 +900,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(m_sharedContext->m_irModule); builder.setInsertBefore(inst); auto newPtrType = - builder.getPtrType(oldPtrType->getOp(), translateToStorageBufferPointer(oldPtrType->getValueType()), storageClass); + builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), storageClass); inst->setFullType(newPtrType); addUsersToWorkList(inst); return; @@ -923,7 +923,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(m_sharedContext->m_irModule); builder.setInsertBefore(inst); auto qualPtrType = builder.getPtrType( - ptrType->getOp(), translateToStorageBufferPointer(ptrType->getValueType()), snippet->resultStorageClass); + ptrType->getOp(), ptrType->getValueType(), snippet->resultStorageClass); List<IRInst*> args; for (UInt i = 0; i < inst->getArgCount(); i++) args.add(inst->getArg(i)); @@ -1011,7 +1011,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } // If we reach here, we need to allocate a temp var. - auto tempVar = builder.emitVar(translateToStorageBufferPointer(ptrType->getValueType())); + auto tempVar = builder.emitVar(ptrType->getValueType()); auto load = builder.emitLoad(arg); builder.emitStore(tempVar, load); newArgs.add(tempVar); @@ -1021,7 +1021,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (writeBacks.getCount()) { auto newCall = builder.emitCallInst( - translateToStorageBufferPointer(inst->getFullType()), + inst->getFullType(), inst->getCallee(), newArgs); for (auto wb : writeBacks) @@ -1033,15 +1033,6 @@ struct SPIRVLegalizationContext : public SourceEmitterBase inst->removeAndDeallocate(); addUsersToWorkList(newCall); } - else - { - // If we reach here, we have determined that all arguments passed as a pointer - // are actual memory objects, so they can be passed in as-is. - // We still need to make sure the callee is specialized to the address-space - // of the arguments, this is done in a separate specialization pass. - - translatePtrResultType(inst); - } } Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar; @@ -1074,7 +1065,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(inst); else setInsertAfterOrdinaryInst(&builder, x); - y = builder.emitVar(translateToStorageBufferPointer(x->getDataType()), SpvStorageClassFunction); + y = builder.emitVar(x->getDataType(), SpvStorageClassFunction); builder.emitStore(y, x); if (x->getParent()->getOp() != kIROp_Module) m_mapArrayValueToVar.set(x, y); @@ -1101,7 +1092,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(gepInst); auto newPtrType = builder.getPtrType( oldResultType->getOp(), - translateToStorageBufferPointer(oldResultType->getValueType()), + oldResultType->getValueType(), ptrType->getAddressSpace()); IRInst* args[2] = { base, index }; auto newInst = @@ -1154,7 +1145,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(offsetPtrInst); builder.setInsertBefore(offsetPtrInst); auto newResultType = builder.getPtrType(resultPtrType->getOp(), - translateToStorageBufferPointer(resultPtrType->getValueType()), + resultPtrType->getValueType(), ptrOperandType->getAddressSpace()); auto newInst = builder.replaceOperand(&offsetPtrInst->typeUse, newResultType); addUsersToWorkList(newInst); @@ -1174,7 +1165,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(loadInst); IRInst* args[] = { sb, index }; auto addrInst = builder.emitIntrinsicInst( - builder.getPtrType(kIROp_PtrType, translateToStorageBufferPointer(loadInst->getFullType()), getStorageBufferStorageClass()), + builder.getPtrType(kIROp_PtrType, loadInst->getFullType(), getStorageBufferStorageClass()), kIROp_RWStructuredBufferGetElementPtr, 2, args); @@ -1357,7 +1348,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase return; auto oldResultType = as<IRPtrTypeBase>(inst->getDataType()); auto oldValueType = oldResultType->getValueType(); - auto newValueType = translateToStorageBufferPointer(oldValueType); + auto newValueType = oldValueType; if (oldValueType != newValueType || oldResultType->getAddressSpace() != ptrType->getAddressSpace()) { @@ -1381,7 +1372,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto ptrType = as<IRPtrType>(inst->getDataType()); if (!ptrType) return; - auto newPtrType = translateToStorageBufferPointer(ptrType); + auto newPtrType = ptrType; if (newPtrType == ptrType) return; IRBuilder builder(inst); @@ -1890,79 +1881,19 @@ struct SPIRVLegalizationContext : public SourceEmitterBase addToWorkList(branch->getOperand(0)); } - // If type is pointer type and does not have an address space, make it a - // storage buffer pointer. - IRType* translateToStorageBufferPointer(IRType* type) - { - if (auto ptrType = as<IRPtrType>(type)) - { - auto oldValueType = ptrType->getValueType(); - auto newValueType = translateToStorageBufferPointer(oldValueType); - if (oldValueType != newValueType || !ptrType->hasAddressSpace()) - { - IRBuilder builder(m_module); - IRIntegerValue addressSpace = (ptrType->hasAddressSpace() ? ptrType->getAddressSpace() : IRIntegerValue(SpvStorageClassPhysicalStorageBuffer)); - return builder.getPtrType(ptrType->getOp(), newValueType, addressSpace); - } - return ptrType; - } - else if (auto arrayTypeBase = as<IRArrayTypeBase>(type)) - { - auto oldValueType = arrayTypeBase->getElementType(); - auto newValueType = translateToStorageBufferPointer(oldValueType); - if (oldValueType != newValueType) - { - IRBuilder builder(m_module); - return builder.getArrayTypeBase(arrayTypeBase->getOp(), newValueType, arrayTypeBase->getElementCount()); - } - return arrayTypeBase; - } - return type; - } - - void translatePtrResultType(IRInst* inst) - { - auto ptrType = as<IRPtrType>(inst->getDataType()); - if (!ptrType) - { - if (auto refType = as<IRRefType>(inst->getDataType())) - { - // Functions that return ref type should be treated as returning a pointer. - IRBuilder builder(inst); - ptrType = builder.getPtrType(refType->getValueType()); - } - } - auto newPtrType = translateToStorageBufferPointer(ptrType); - if (newPtrType == ptrType) - return; - IRBuilder builder(inst); - auto newInst = builder.replaceOperand(&inst->typeUse, newPtrType); - addUsersToWorkList(newInst); - } - void processPtrLit(IRInst* inst) { IRBuilder builder(inst); builder.setInsertBefore(inst); - auto newPtrType = translateToStorageBufferPointer(as<IRPtrType>(inst->getFullType())); + auto newPtrType = as<IRPtrType>(inst->getFullType()); auto newInst = builder.emitCastIntToPtr(newPtrType, builder.getIntValue(builder.getUInt64Type(), 0)); inst->replaceUsesWith(newInst); addUsersToWorkList(newInst); } - void processPtrCast(IRInst* cast) - { - translatePtrResultType(cast); - } - - void processLoad(IRInst* inst) - { - translatePtrResultType(inst); - } - void processStructField(IRStructField* field) { - auto newFieldType = translateToStorageBufferPointer(field->getFieldType()); + auto newFieldType = field->getFieldType(); if (newFieldType != field->getFieldType()) field->setFieldType(newFieldType); } @@ -2095,17 +2026,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_MakeOptionalNone: processConstructor(inst); break; - case kIROp_BitCast: - case kIROp_PtrCast: - case kIROp_CastIntToPtr: - processPtrCast(inst); - break; case kIROp_PtrLit: processPtrLit(inst); break; - case kIROp_Load: - processLoad(inst); - break; case kIROp_unconditionalBranch: processBranch(inst); break; @@ -2261,7 +2184,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase { // Don't assign address space to additional insts, since we should have // already assigned address space to them in earlier stages of legalization. - auto type = unwrapAttributedType(inst->getDataType()); + auto type = inst->getDataType(); + for (;;) + { + auto newType = (IRType*)unwrapAttributedType(type); + newType = unwrapArray(newType); + if (newType == type) break; + type = newType; + } if (!type) return AddressSpace::Generic; return getAddressSpaceFromVarType(type); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ba7376c46..bb05a1c29 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2813,9 +2813,9 @@ namespace Slang return (IRInOutType*) getPtrType(kIROp_InOutType, valueType); } - IRRefType* IRBuilder::getRefType(IRType* valueType) + IRRefType* IRBuilder::getRefType(IRType* valueType, AddressSpace addrSpace) { - return (IRRefType*) getPtrType(kIROp_RefType, valueType); + return (IRRefType*) getPtrType(kIROp_RefType, valueType, addrSpace); } IRConstRefType* IRBuilder::getConstRefType(IRType* valueType) @@ -2840,8 +2840,13 @@ namespace Slang IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace) { - IRInst* operands[] = {valueType, getIntValue(getIntType(), addressSpace)}; - return (IRPtrType*)getType(op, 2, operands); + return (IRPtrType*)getPtrType(op, valueType, getIntValue(getUInt64Type(), addressSpace)); + } + + IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRInst* addressSpace) + { + IRInst* operands[] = { valueType, addressSpace }; + return (IRPtrType*)getType(op, addressSpace ? 2 : 1, operands); } IRTextureTypeBase* IRBuilder::getTextureType(IRType* elementType, IRInst* shape, IRInst* isArray, IRInst* isMS, IRInst* sampleCount, IRInst* access, IRInst* isShadow, IRInst* isCombined, IRInst* format) @@ -4881,11 +4886,30 @@ namespace Slang return inst; } + IRType* maybePropagateAddressSpace(IRBuilder* builder, IRInst* basePtr, IRType* type) + { + if (auto basePtrType = as<IRPtrTypeBase>(basePtr->getDataType())) + { + if (auto resultPtrType = as<IRPtrTypeBase>(type)) + { + if (basePtrType->getAddressSpace() != resultPtrType->getAddressSpace()) + { + type = builder->getPtrType( + resultPtrType->getOp(), resultPtrType->getValueType(), basePtrType->getAddressSpace()); + } + } + } + return type; + } + IRInst* IRBuilder::emitFieldAddress( IRType* type, IRInst* base, IRInst* field) { + // Propagate pointer address space if it is available on base. + type = maybePropagateAddressSpace(this, base, type); + auto inst = createInst<IRFieldAddress>( this, kIROp_FieldAddress, @@ -4982,6 +5006,9 @@ namespace Slang IRInst* basePtr, IRInst* index) { + // Propagate pointer address space if it is available on base. + type = maybePropagateAddressSpace(this, basePtr, type); + auto inst = createInst<IRFieldAddress>( this, kIROp_GetElementPtr, @@ -5004,9 +5031,20 @@ namespace Slang IRInst* basePtr, IRInst* index) { + AddressSpace addrSpace = AddressSpace::Generic; + IRInst* valueType = nullptr; + auto basePtrType = unwrapAttributedType(basePtr->getDataType()); + if (auto ptrType = as<IRPtrTypeBase>(basePtrType)) + { + addrSpace = ptrType->getAddressSpace(); + valueType = ptrType->getValueType(); + } + else if (auto ptrLikeType = as<IRPointerLikeType>(basePtrType)) + { + valueType = ptrLikeType->getElementType(); + } IRType* type = nullptr; - auto basePtrType = as<IRPtrTypeBase>(basePtr->getDataType()); - auto valueType = unwrapAttributedType(basePtrType->getValueType()); + valueType = unwrapAttributedType(valueType); if (auto arrayType = as<IRArrayTypeBase>(valueType)) { type = arrayType->getElementType(); @@ -5028,7 +5066,7 @@ namespace Slang auto inst = createInst<IRGetElementPtr>( this, kIROp_GetElementPtr, - getPtrType(type), + getPtrType(kIROp_PtrType, type, addrSpace), basePtr, index); @@ -5058,7 +5096,7 @@ namespace Slang } } SLANG_RELEASE_ASSERT(resultType); - basePtr = emitFieldAddress(getPtrType(resultType), basePtr, structKey); + basePtr = emitFieldAddress(getPtrType(kIROp_PtrType, resultType, basePtrType->getAddressSpace()), basePtr, structKey); } else { diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 9a773a891..cb8e83df8 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -39,17 +39,6 @@ struct IRModule; struct IRStructField; struct IRStructKey; -enum class AddressSpace -{ - Generic = 0x7fffffff, - ThreadLocal = 1, - Global = 2, - GroupShared = 3, - Uniform = 4, - // specific address space for payload data in metal - MetalObjectData = 5, -}; - typedef unsigned int IROpFlags; enum : IROpFlags { @@ -1710,11 +1699,11 @@ struct IRPtrTypeBase : IRType { IRType* getValueType() { return (IRType*)getOperand(0); } - bool hasAddressSpace() { return getOperandCount() > 1; } + bool hasAddressSpace() { return getOperandCount() > 1 && getAddressSpace() != AddressSpace::Generic; } - IRIntegerValue getAddressSpace() + AddressSpace getAddressSpace() { - return getOperandCount() > 1 ? static_cast<IRIntLit*>(getOperand(1))->getValue() : -1; + return getOperandCount() > 1 ? (AddressSpace)static_cast<IRIntLit*>(getOperand(1))->getValue() : AddressSpace::Generic; } IR_PARENT_ISA(PtrTypeBase) diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index 393b34e68..66c0044b6 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -209,8 +209,8 @@ bool isPointerToResourceType(IRType* type) { while (auto ptrType = as<IRPtrTypeBase>(type)) { - if (ptrType->getAddressSpace() == SpvStorageClassStorageBuffer || - ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBufferEXT) + if (ptrType->getAddressSpace() == AddressSpace(SpvStorageClassStorageBuffer) || + ptrType->getAddressSpace() == AddressSpace::UserPointer) return true; type = ptrType->getValueType(); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d3770753c..74aa0a0ee 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1855,8 +1855,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto astValueType = type->getValueType(); IRType* irValueType = lowerType(context, astValueType); - - return getBuilder()->getPtrType(irValueType); + IRInst* addrSpace = nullptr; + if (auto astAddrSpace = type->getAddressSpace()) + { + addrSpace = getSimpleVal(context, lowerVal(context, astAddrSpace)); + } + else + { + addrSpace = getBuilder()->getIntValue(getBuilder()->getUInt64Type(), (IRIntegerValue)AddressSpace::Generic); + } + return getBuilder()->getPtrType(kIROp_PtrType, irValueType, addrSpace); } IRType* visitDeclRefType(DeclRefType* type) @@ -3138,7 +3146,7 @@ void _lowerFuncDeclBaseTypeInfo( irParamType = builder->getInOutType(irParamType); break; case kParameterDirection_Ref: - irParamType = builder->getRefType(irParamType); + irParamType = builder->getRefType(irParamType, AddressSpace::Generic); break; case kParameterDirection_ConstRef: irParamType = builder->getConstRefType(irParamType); @@ -4972,7 +4980,6 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> case LoweredValInfo::Flavor::Ptr: return LoweredValInfo::ptr( builder->emitElementAddress( - context->irBuilder->getPtrType(type), baseVal.val, indexVal)); diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 0b92d07af..d6efb47b3 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -706,7 +706,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt auto paramType = getParamType(astBuilder, paramDeclRef); if( paramDecl->findModifier<RefModifier>() ) { - paramType = astBuilder->getRefType(paramType); + paramType = astBuilder->getRefType(paramType, AddressSpace::Generic); } else if (paramDecl->findModifier<ConstRefModifier>()) { diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h index 2f467a05a..404c84cf4 100644 --- a/source/slang/slang-type-system-shared.h +++ b/source/slang/slang-type-system-shared.h @@ -58,6 +58,20 @@ FOREACH_BASE_TYPE(DEFINE_BASE_TYPE) const int kStdlibTextureIsShadowParameterIndex = 6; const int kStdlibTextureIsCombinedParameterIndex = 7; const int kStdlibTextureFormatParameterIndex = 8; + + enum class AddressSpace : uint64_t + { + Generic = 0x7fffffff, + ThreadLocal = 1, + Global = 2, + GroupShared = 3, + Uniform = 4, + // specific address space for payload data in metal + MetalObjectData = 5, + + // Default address space for a user-defined pointer + UserPointer = 0x100000001ULL, + }; } #endif |
