diff options
Diffstat (limited to 'source')
43 files changed, 759 insertions, 199 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 3306403f5..5d2a80c29 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1278,18 +1278,52 @@ struct __none_t { }; +// @hidden: this type is a BaseType since we want it to work with +// `registerBuiltinDecl` +__builtin_type($((int)BaseType::AddressSpace)) +enum AddressSpace : uint64_t +{ + Device = $((uint64_t)AddressSpace::UserPointer), + GroupShared = $((uint64_t)AddressSpace::GroupShared), +}; + +// @hidden: this type is a BaseType since we want it to work with +// `registerBuiltinDecl` +__builtin_type($((int)BaseType::MemoryScope)) +enum MemoryScope : int32_t +{ + CrossDevice = $((int32_t)MemoryScope::CrossDevice), + Device = $((int32_t)MemoryScope::Device), + Workgroup = $((int32_t)MemoryScope::Workgroup), + Subgroup = $((int32_t)MemoryScope::Subgroup), + Invocation = $((int32_t)MemoryScope::Invocation), + QueueFamily = $((int32_t)MemoryScope::QueueFamily), +} + +// @hidden: this type is a BaseType since we want it to work with +// `registerBuiltinDecl` +__builtin_type($((int)BaseType::AccessQualifier)) +enum Access : uint64_t +{ + ReadWrite = $((uint64_t)AccessQualifier::ReadWrite), + Read = $((uint64_t)AccessQualifier::Read), +} + //@public: /// Represents a pointer type. /// @param T The type of the value pointed to. /// @remarks `T* val` is equivalent to `Ptr<T> val`. -__generic<T, let addrSpace : uint64_t = $((uint64_t)AddressSpace::UserPointer)ULL> __magic_type(PtrType) __intrinsic_type($(kIROp_PtrType)) -struct Ptr +struct Ptr< + T, + Access access = Access::ReadWrite, + AddressSpace addrSpace = AddressSpace::Device> { - __generic<U> + // A user is allowed to explicitly cast between any pointer type of + // the same address space __intrinsic_op($(kIROp_BitCast)) - __init(Ptr<U, addrSpace> ptr); + __init<U, Access accessOther>(Ptr<U, accessOther, addrSpace> ptr); __intrinsic_op($(kIROp_CastIntToPtr)) __init(uint64_t val); @@ -1297,16 +1331,30 @@ struct Ptr __intrinsic_op($(kIROp_CastIntToPtr)) __init(int64_t val); + // By default, getter is not an L value __generic<TInt : __BuiltinIntegerType> __subscript(TInt index) -> T { - // If a 'Ptr[index]' is referred to by a '__ref', call 'kIROp_GetOffsetPtr(index)' __intrinsic_op($(kIROp_GetOffsetPtr)) [nonmutating] ref; } }; +extension<T, AddressSpace addrSpace> Ptr<T, Access::ReadWrite, addrSpace> +{ + // We have a `ref` accessor if we are ReadWrite. This means only `ReadWrite` + // can be used as an L-value. + __generic<TInt : __BuiltinIntegerType> + __subscript(TInt index) -> Ref<T> + { + // If a 'Ptr[index]' is referred to by a '__ref', call 'kIROp_GetOffsetPtr(index)' + __intrinsic_op($(kIROp_GetOffsetPtr)) + [nonmutating] + ref; + } +} + //@hidden: __intrinsic_op($(kIROp_AlignedAttr)) void __align_attr(int alignment); @@ -1348,50 +1396,64 @@ void storeAligned<int alignment, T>(T* ptr, T value) __store_aligned(ptr, value, __align_attr(alignment)); } +${{{ + StringBuilder ptrTypeParameterListBuilder; + ptrTypeParameterListBuilder << "T, Access access, AddressSpace addrSpace"; + String ptrTypeParameterList = ptrTypeParameterListBuilder.toString(); + + StringBuilder ptrArgListBuilder; + ptrArgListBuilder << "T, access, addrSpace"; + String ptrArgList = ptrArgListBuilder.toString(); + + StringBuilder fullPtrTypeBuilder; + fullPtrTypeBuilder << "Ptr<" << ptrArgList << ">"; + String fullPtrType = fullPtrTypeBuilder.toString(); + +}}} //@hidden: __intrinsic_op($(kIROp_Load)) -T __load<T, let addrSpace : uint64_t>(Ptr<T, addrSpace> ptr); +T __load<$(ptrTypeParameterList)>($(fullPtrType) ptr); __intrinsic_op($(kIROp_Store)) -void __store<T, let addrSpace : uint64_t>(Ptr<T, addrSpace> ptr, T val); +void __store<$(ptrTypeParameterList)>($(fullPtrType) ptr, T val); __intrinsic_op($(kIROp_GetElementPtr)) -Ptr<T, addrSpace> __getElementPtr<T, let addrSpace : uint64_t, TIndex : __BuiltinIntegerType>(Ptr<T, addrSpace> ptr, TIndex index); +$(fullPtrType) __getElementPtr<$(ptrTypeParameterList), TIndex : __BuiltinIntegerType>($(fullPtrType) ptr, TIndex index); __intrinsic_op($(kIROp_GetOffsetPtr)) -Ptr<T, addrSpace> __getOffsetPtr<T, let addrSpace : uint64_t, TIndex : __BuiltinIntegerType>(Ptr<T, addrSpace> ptr, TIndex index); +$(fullPtrType) __getOffsetPtr<$(ptrTypeParameterList), TIndex : __BuiltinIntegerType>($(fullPtrType) ptr, TIndex index); -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_Less)) -bool operator <(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); +bool operator <($(fullPtrType) p1, $(fullPtrType) p2); -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_Leq)) -bool operator <=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); +bool operator <=($(fullPtrType) p1, $(fullPtrType) p2); -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_Greater)) -bool operator>(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); +bool operator>($(fullPtrType) p1, $(fullPtrType) p2); -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_Geq)) -bool operator >=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); +bool operator >=($(fullPtrType) p1, $(fullPtrType) p2); -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_Neq)) -bool operator !=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); +bool operator !=($(fullPtrType) p1, $(fullPtrType) p2); -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_Eql)) -bool operator ==(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2); +bool operator ==($(fullPtrType) p1, $(fullPtrType) p2); //@public: extension bool : IRangedValue { - __generic<T, let addrSpace : uint64_t> + __generic<$(ptrTypeParameterList)> __implicit_conversion($(kConversionCost_PtrToBool)) __intrinsic_op($(kIROp_CastPtrToBool)) - __init(Ptr<T, addrSpace> ptr); + __init($(fullPtrType) ptr); __generic<T : __EnumType> __implicit_conversion($(kConversionCost_IntegerTruncate)) @@ -1407,9 +1469,9 @@ extension bool : IRangedValue extension uint64_t : IRangedValue { - __generic<T, let addrSpace : uint64_t> + __generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_CastPtrToInt)) - __init(Ptr<T, addrSpace> ptr); + __init($(fullPtrType) ptr); static const uint64_t maxValue = 0xFFFFFFFFFFFFFFFFULL; static const uint64_t minValue = 0; @@ -1417,9 +1479,9 @@ extension uint64_t : IRangedValue extension int64_t : IRangedValue { - __generic<T, let addrSpace : uint64_t> + __generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_CastPtrToInt)) - __init(Ptr<T, addrSpace> ptr); + __init($(fullPtrType) ptr); static const int64_t maxValue = 0x7FFFFFFFFFFFFFFFLL; static const int64_t minValue = -0x8000000000000000LL; @@ -1427,9 +1489,9 @@ extension int64_t : IRangedValue extension intptr_t : IRangedValue { - __generic<T, let addrSpace : uint64_t> + __generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_CastPtrToInt)) - __init(Ptr<T, addrSpace> ptr); + __init($(fullPtrType) ptr); static const intptr_t maxValue = $(SLANG_PROCESSOR_X86_64?"0x7FFFFFFFFFFFFFFFz":"0x7FFFFFFFz"); static const intptr_t minValue = $(SLANG_PROCESSOR_X86_64?"0x8000000000000000z":"0x80000000z"); static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4"); @@ -1437,9 +1499,9 @@ extension intptr_t : IRangedValue extension uintptr_t : IRangedValue { - __generic<T, let addrSpace : uint64_t> + __generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_CastPtrToInt)) - __init(Ptr<T, addrSpace> ptr); + __init($(fullPtrType) ptr); static const uintptr_t maxValue = $(SLANG_PROCESSOR_X86_64?"0xFFFFFFFFFFFFFFFFz":"0xFFFFFFFFz"); static const uintptr_t minValue = 0z; static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4"); @@ -1470,7 +1532,9 @@ __intrinsic_type($(kIROp_ConstRefType)) struct ConstRef {}; -typealias __Addr<T> = Ptr<T, $((uint64_t)AddressSpace::Generic)ULL>; +// __Addr<T> is AddressSpace::Generic since Slang will specalize & validate the address-space +// internally to a concrete address-space. +typealias __Addr<T> = Ptr<T, Access::ReadWrite, (AddressSpace)$((uint64_t)AddressSpace::Generic)>; //@public: @@ -1828,16 +1892,16 @@ struct NativeString __init() { this = NativeString(""); } }; -extension Ptr<void> +extension<Access access> Ptr<void, access> { __implicit_conversion($(kConversionCost_PtrToVoidPtr)) [__unsafeForceInlineEarly] - __init(NativeString nativeStr) { this = nativeStr.getBuffer(); } + __init(NativeString nativeStr) { this = Ptr<void, access>(nativeStr.getBuffer()); } - __generic<T, let addrSpace : uint64_t> + __generic<$(ptrTypeParameterList)> __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) - __init(Ptr<T, addrSpace> ptr); + __init($(fullPtrType) ptr); __generic<T> __intrinsic_op($(kIROp_BitCast)) @@ -2607,29 +2671,31 @@ for (auto op : intrinsicUnaryOps) }}}} -__generic<T, let addrSpace : uint64_t> +// Only ReadWrite is an L-value. +__generic<T, AddressSpace addrSpace> __intrinsic_op(0) -[require(cpp_cuda_spirv)] -__prefix Ref<T> operator*(Ptr<T, addrSpace> value); +__prefix Ref<T> operator*(Ptr<T, Access::ReadWrite, addrSpace> value); -__generic<T> +// Unknown access qualifier or Access::Read access qualifier is a promise +// that the pointer is not going to be used as an L-value. +__generic<$(ptrTypeParameterList)> __intrinsic_op(0) -[KnownBuiltin($( (int)KnownBuiltinDeclName::OperatorAddressOf))] -[require(cpp_cuda_spirv)] -__prefix Ptr<T, $((uint64_t)AddressSpace::UserPointer)ULL> operator&(__ref T value); +__prefix ConstRef<T> operator*($(fullPtrType) value); +// TODO: [require(cpu)]. This cannot be done yet since this change breaks slangpy __generic<T> __intrinsic_op(0) +[KnownBuiltin( $((int)KnownBuiltinDeclName::OperatorAddressOf))] [require(cpp_cuda_spirv)] -__Addr<T> __get_addr( __ref T value); +__prefix Ptr<T, Access::ReadWrite, AddressSpace::Device> operator&(__ref T value); -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList), TInt : __BuiltinIntegerType> __intrinsic_op($(kIROp_GetOffsetPtr)) -Ptr<T, addrSpace> operator+(Ptr<T, addrSpace> value, int64_t offset); +$(fullPtrType) operator+($(fullPtrType) value, TInt offset); -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList), TInt : __BuiltinIntegerType> [__unsafeForceInlineEarly] -Ptr<T, addrSpace> operator -(Ptr<T, addrSpace> value, int64_t offset) +$(fullPtrType) operator-($(fullPtrType) value, TInt offset) { return __getOffsetPtr(value, -offset); } @@ -2694,9 +2760,9 @@ matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value) {$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); } $(fixity.qual) -__generic<T, let addrSpace : uint64_t> +__generic<$(ptrTypeParameterList)> [__unsafeForceInlineEarly] -Ptr<T, addrSpace> operator$(op.name)(in out Ptr<T, addrSpace> value) +$(fullPtrType) operator$(op.name)(in out $(fullPtrType) value) {$(fixity.bodyPrefix) value = value $(op.binOp) 1; return $(fixity.returnVal); } ${{{{ @@ -3556,18 +3622,6 @@ enum MemoryOrder SeqCst = $(kIRMemoryOrder_SeqCst), } -// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id -enum MemoryScope -{ - CrossDevice = 0, - Device = 1, - Workgroup = 2, - Subgroup = 3, - Invocation = 4, - QueueFamily = 5, - ShaderCallKHR = 6, -}; - /// Represents types that can be used in any atomic operations. /// Implemented by builtin scalar types: `int`, `uint`, `int64_t`, `uint64_t`, `int8_t`, `uint8_t`, `int16_t`, `uint16_t`, `float`, `double` and `half`. [sealed] interface IAtomicable {} @@ -4307,7 +4361,7 @@ __attributeTarget(FuncDecl) attribute_syntax [RequireFullQuads] : RequireFullQuadsAttribute; __generic<T> -typealias NodePayloadPtr = Ptr<T, $((uint64_t)AddressSpace::NodePayloadAMDX)>; +typealias NodePayloadPtr = Ptr<T, Access::ReadWrite, (AddressSpace)$((uint64_t)AddressSpace::NodePayloadAMDX)>; __attributeTarget(StructDecl) attribute_syntax [raypayload] : RayPayloadAttribute; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 23afb3297..2af0dbcf7 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -19859,7 +19859,7 @@ __Addr<T> __allocHitObjectAttributes<T>() { [__vulkanHitObjectAttributes] static T t; - return __get_addr(t); + return __getAddress(t); } // Next is the custom intrinsic that will compute the hitObjectAttributes location diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index a71abf570..5da4e9521 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -461,9 +461,17 @@ Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const c return rsType; } -PtrType* ASTBuilder::getPtrType(Type* valueType, AddressSpace addrSpace) +PtrType* ASTBuilder::getPtrType(Type* valueType, Val* accessQualifier, Val* addrSpace) { - return dynamicCast<PtrType>(getPtrType(valueType, addrSpace, "PtrType")); + return dynamicCast<PtrType>(getPtrType(valueType, accessQualifier, addrSpace, "PtrType")); +} + +PtrType* ASTBuilder::getPtrType( + Type* valueType, + AccessQualifier accessQualifier, + AddressSpace addrSpace) +{ + return dynamicCast<PtrType>(getPtrType(valueType, accessQualifier, addrSpace, "PtrType")); } Type* ASTBuilder::getDefaultLayoutType() @@ -489,11 +497,6 @@ Type* ASTBuilder::getScalarLayoutType() return getSpecializedBuiltinType({}, "ScalarDataLayoutType"); } -Type* ASTBuilder::getCLayoutType() -{ - return getSpecializedBuiltinType({}, "CDataLayoutType"); -} - // Construct the type `Out<valueType>` OutType* ASTBuilder::getOutType(Type* valueType) { @@ -505,9 +508,9 @@ InOutType* ASTBuilder::getInOutType(Type* valueType) return dynamicCast<InOutType>(getPtrType(valueType, "InOutType")); } -RefType* ASTBuilder::getRefType(Type* valueType, AddressSpace addrSpace) +RefType* ASTBuilder::getRefType(Type* valueType) { - return dynamicCast<RefType>(getPtrType(valueType, addrSpace, "RefType")); + return dynamicCast<RefType>(getPtrType(valueType, "RefType")); } ConstRefType* ASTBuilder::getConstRefType(Type* valueType) @@ -528,13 +531,27 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) PtrTypeBase* ASTBuilder::getPtrType( Type* valueType, - AddressSpace addrSpace, + Val* accessQualifier, + Val* addrSpace, char const* ptrTypeName) { - Val* args[] = {valueType, getIntVal(getUInt64Type(), (IntegerLiteralValue)addrSpace)}; + Val* args[] = {valueType, accessQualifier, addrSpace}; return as<PtrTypeBase>(getSpecializedBuiltinType(makeArrayView(args), ptrTypeName)); } +PtrTypeBase* ASTBuilder::getPtrType( + Type* valueType, + AccessQualifier accessQualifier, + AddressSpace addrSpace, + char const* ptrTypeName) +{ + return as<PtrTypeBase>(getPtrType( + valueType, + getIntVal(getBuiltinType(BaseType::AccessQualifier), (IntegerLiteralValue)accessQualifier), + getIntVal(getBuiltinType(BaseType::AddressSpace), (IntegerLiteralValue)addrSpace), + ptrTypeName)); +} + ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount) { if (!elementCount) diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 798e1ddc0..e71e2665f 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -530,7 +530,8 @@ public: Type* getDiffInterfaceType() { return m_sharedASTBuilder->getDiffInterfaceType(); } // Construct the type `Ptr<valueType>`, where `Ptr` // is looked up as a builtin type. - PtrType* getPtrType(Type* valueType, AddressSpace addrSpace); + PtrType* getPtrType(Type* valueType, AccessQualifier accessQualifier, AddressSpace addrSpace); + PtrType* getPtrType(Type* valueType, Val* accessQualifier, Val* addrSpace); // Construct the type `Out<valueType>` OutType* getOutType(Type* valueType); @@ -539,7 +540,7 @@ public: InOutType* getInOutType(Type* valueType); // Construct the type `Ref<valueType>` - RefType* getRefType(Type* valueType, AddressSpace addrSpace); + RefType* getRefType(Type* valueType); // Construct the type `ConstRef<valueType>` ConstRefType* getConstRefType(Type* valueType); @@ -550,7 +551,16 @@ public: // Construct a pointer type like `Ptr<valueType>`, but where // the actual type name for the pointer type is given by `ptrTypeName` PtrTypeBase* getPtrType(Type* valueType, char const* ptrTypeName); - PtrTypeBase* getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName); + PtrTypeBase* getPtrType( + Type* valueType, + Val* accessQualifier, + Val* addrSpace, + char const* ptrTypeName); + PtrTypeBase* getPtrType( + Type* valueType, + AccessQualifier accessQualifier, + AddressSpace addrSpace, + char const* ptrTypeName); ArrayExpressionType* getArrayType(Type* elementType, IntVal* elementCount); diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index adbc7a2ba..fb0ac2a67 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -503,6 +503,13 @@ class CountOfExpr : public SizeOfLikeExpr }; FIDDLE() +class AddressOfExpr : public Expr +{ + FIDDLE(...) + FIDDLE() Expr* arg = nullptr; +}; + +FIDDLE() class MakeOptionalExpr : public Expr { FIDDLE(...) diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index c29a42665..379de0560 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -251,6 +251,7 @@ struct ASTIterator iterator->maybeDispatchCallback(expr); } void visitReturnValExpr(ReturnValExpr* expr) { iterator->maybeDispatchCallback(expr); } + void visitAddressOfExpr(AddressOfExpr* expr) { iterator->maybeDispatchCallback(expr); } void visitAndTypeExpr(AndTypeExpr* expr) { diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 2f9c29d6c..89f7a70bb 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -583,7 +583,6 @@ class BuiltinRequirementModifier : public Modifier FIDDLE() BuiltinRequirementKind kind; }; - // A modifier applied to declarations of builtin types to indicate how they // should be lowered to the IR. // diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index 1f7478662..299e6a859 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -610,6 +610,15 @@ void ASTPrinter::addExpr(Expr* expr) } sb << ")"; } + else if (const auto addressOfExpr = as<AddressOfExpr>(expr)) + { + sb << "__getAddress("; + if (addressOfExpr->arg) + { + addExpr(addressOfExpr->arg); + } + sb << ")"; + } else if (const auto makeOptionalExpr = as<MakeOptionalExpr>(expr)) { if (makeOptionalExpr->value) diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index 3ac352f0a..1f5f4f8b2 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -15,6 +15,10 @@ QualType::QualType(Type* type) { isLeftValue = true; } + else if (as<ConstRefType>(type)) + { + isLeftValue = false; + } } void removeModifier(ModifiableSyntaxNode* syntax, Modifier* toRemove) diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 8a224b305..53f6626d7 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -439,11 +439,28 @@ Type* NativeRefType::getValueType() return as<Type>(_getGenericTypeArg(this, 0)); } -Val* PtrTypeBase::getAddressSpace() + +Val* PtrTypeBase::getAccessQualifier() { return _getGenericTypeArg(this, 1); } +Val* PtrTypeBase::getAddressSpace() +{ + return _getGenericTypeArg(this, 2); +} + +AccessQualifier tryGetAccessQualifierValue(Val* val) +{ + AccessQualifier accessQualifier = AccessQualifier::ReadWrite; + + if (auto cintVal = as<ConstantIntVal>(val)) + { + accessQualifier = (AccessQualifier)(cintVal->getValue()); + } + return accessQualifier; +} + AddressSpace tryGetAddressSpaceValue(Val* addrSpaceVal) { AddressSpace addrSpace = AddressSpace::Generic; @@ -460,19 +477,38 @@ void maybePrintAddrSpaceOperand(StringBuilder& out, AddressSpace addrSpace) switch (addrSpace) { case AddressSpace::Generic: + out << toSlice(", AddressSpace::Generic"); + break; case AddressSpace::UserPointer: + // We expose UserPointer as Device to users + out << toSlice(", AddressSpace::Device"); break; case AddressSpace::GroupShared: - out << toSlice(", groupshared"); + out << toSlice(", AddressSpace::GroupShared"); break; case AddressSpace::Global: - out << toSlice(", global"); + out << toSlice(", AddressSpace::Global"); break; case AddressSpace::ThreadLocal: - out << toSlice(", threadlocal"); + out << toSlice(", AddressSpace::ThreadLocal"); break; case AddressSpace::Uniform: - out << toSlice(", uniform"); + out << toSlice(", AddressSpace::Uniform"); + break; + default: + break; + } +} + +void maybePrintAccessQualifierOperand(StringBuilder& out, AccessQualifier accessQualifier) +{ + switch (accessQualifier) + { + case AccessQualifier::ReadWrite: + out << toSlice(", Access::ReadWrite"); + break; + case AccessQualifier::Read: + out << toSlice(", Access::Read"); break; default: break; @@ -481,20 +517,21 @@ void maybePrintAddrSpaceOperand(StringBuilder& out, AddressSpace addrSpace) void PtrType::_toTextOverride(StringBuilder& out) { + auto accessQualifier = tryGetAccessQualifierValue(getAccessQualifier()); auto addrSpace = tryGetAddressSpaceValue(getAddressSpace()); - if (addrSpace == AddressSpace::Generic) - out << toSlice("Addr<") << getValueType(); - else - out << toSlice("Ptr<") << getValueType(); + out << toSlice("Ptr<") << getValueType(); + maybePrintAccessQualifierOperand(out, accessQualifier); maybePrintAddrSpaceOperand(out, addrSpace); out << toSlice(">"); } void RefType::_toTextOverride(StringBuilder& out) { + auto accessQualifier = tryGetAccessQualifierValue(getAccessQualifier()); + auto addrSpace = tryGetAddressSpaceValue(getAddressSpace()); out << toSlice("Ref<") << getValueType(); - auto addressSpaceVal = getAddressSpace(); - maybePrintAddrSpaceOperand(out, tryGetAddressSpaceValue(addressSpaceVal)); + maybePrintAccessQualifierOperand(out, accessQualifier); + maybePrintAddrSpaceOperand(out, addrSpace); out << toSlice(">"); } diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 842af8b88..4994328b2 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -685,7 +685,7 @@ class PtrTypeBase : public BuiltinType FIDDLE(...) // Get the type of the pointed-to value. Type* getValueType(); - + Val* getAccessQualifier(); Val* getAddressSpace(); }; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 43bd99f19..96f0a1682 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -1435,6 +1435,10 @@ Val* TypeCastIntVal::tryFoldImpl( case BaseType::UInt8: resultValue = (uint8_t)resultValue; return true; + case BaseType::AddressSpace: + case BaseType::AccessQualifier: + case BaseType::MemoryScope: + return true; default: return false; } diff --git a/source/slang/slang-base-type-info.cpp b/source/slang/slang-base-type-info.cpp index 9072e34e4..984437ca8 100644 --- a/source/slang/slang-base-type-info.cpp +++ b/source/slang/slang-base-type-info.cpp @@ -4,7 +4,7 @@ namespace Slang { -/* static */ const BaseTypeInfo BaseTypeInfo::s_info[Index(BaseType::CountOf)] = { +/* static */ const BaseTypeInfo BaseTypeInfo::s_info[Index(BaseType::CountOfPrimitives)] = { {0, 0, uint8_t(BaseType::Void)}, {uint8_t(sizeof(bool)), 0, uint8_t(BaseType::Bool)}, {uint8_t(sizeof(int8_t)), @@ -84,6 +84,12 @@ namespace Slang return UnownedStringSlice::fromLiteral("intptr_t"); case BaseType::UIntPtr: return UnownedStringSlice::fromLiteral("uintptr_t"); + case BaseType::AddressSpace: + return UnownedStringSlice::fromLiteral("AddressSpace"); + case BaseType::MemoryScope: + return UnownedStringSlice::fromLiteral("MemoryScope"); + case BaseType::AccessQualifier: + return UnownedStringSlice::fromLiteral("Access"); default: { SLANG_ASSERT(!"Unknown basic type"); diff --git a/source/slang/slang-base-type-info.h b/source/slang/slang-base-type-info.h index 4b96af18f..bad70c6fa 100644 --- a/source/slang/slang-base-type-info.h +++ b/source/slang/slang-base-type-info.h @@ -44,7 +44,7 @@ struct BaseTypeInfo static bool check(); private: - static const BaseTypeInfo s_info[Index(BaseType::CountOf)]; + static const BaseTypeInfo s_info[Index(BaseType::CountOfPrimitives)]; }; } // namespace Slang diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 0ea43a8df..822356312 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -271,6 +271,10 @@ alias cpp_cuda = cpp | cuda; /// [Compound] alias cpp_cuda_spirv = cpp | cuda | spirv; +/// CPP, CUDA, Metal, and SPIRV code-gen targets +/// [Compound] +alias cpp_cuda_metal_spirv = cpp | cuda | metal | spirv; + /// CUDA and SPIRV code-gen targets /// [Compound] alias cuda_spirv = cuda | spirv; diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 7c6f8929a..9355834b5 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -121,7 +121,7 @@ Type* SemanticsVisitor::_tryJoinTypeWithInterface( ConversionCost bestCost = kConversionCost_Explicit; if (auto basicType = dynamicCast<BasicExpressionType>(type)) { - for (Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); + for (Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOfPrimitives); baseTypeFlavorIndex++) { // Don't consider `type`, since we already know it doesn't work. diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 758c23a5f..4d15fd840 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1760,6 +1760,13 @@ bool SemanticsVisitor::_coerce( if (sink) { sink->diagnose(fromExpr, Diagnostics::ambiguousConversion, fromType, toType); + for (auto candidate : overloadContext.bestCandidates) + { + sink->diagnose( + candidate.item.declRef, + Diagnostics::seeDeclarationOf, + candidate.item.declRef); + } } *outToExpr = CreateErrorExpr(fromExpr); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 507e12fa6..e59cf6ad5 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -678,8 +678,12 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase, return; return DeclVisitor<VisitorType>::dispatch(val); } + // Expr Visitor void visitExpr(Expr*) {} + + void visitOpenRefExpr(OpenRefExpr* expr) { dispatchIfNotNull(expr->innerExpr); } + void visitIndexExpr(IndexExpr* subscriptExpr) { for (auto arg : subscriptExpr->indexExprs) @@ -695,6 +699,7 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase, dispatchIfNotNull(element); } + void visitAddressOfExpr(AddressOfExpr* expr) { dispatchIfNotNull(expr->arg); } void visitAssignExpr(AssignExpr* expr) { @@ -2360,6 +2365,13 @@ void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl) addModifier(varDecl, m_astBuilder->create<ExternCppModifier>()); } + // Not allowed a `globallycoherent T*` or related + if (as<PtrType>(varDecl->type)) + if (auto memoryQualifierSet = varDecl->findModifier<MemoryQualifierSetModifier>()) + if (memoryQualifierSet->getMemoryQualifierBit() & + MemoryQualifierSetModifier::Flags::kCoherent) + getSink()->diagnose(varDecl, Diagnostics::coherentKeywordOnAPointer); + // Check for static const variables without initializers if (!varDecl->initExpr) { @@ -14379,6 +14391,12 @@ struct CapabilityDeclReferenceVisitor { handleProcessFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc); } + void visitAddressOfExpr(AddressOfExpr* expr) + { + // __getAddress only works with certain targets + handleProcessFunc(expr, CapabilitySet(CapabilityName::cpp_cuda_metal_spirv), expr->loc); + this->dispatchIfNotNull(expr->arg); + } void visitTargetSwitchStmt(TargetSwitchStmt* stmt) { CapabilitySet set; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index a874eaf43..ec249f56d 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -230,6 +230,7 @@ Expr* SemanticsVisitor::maybeOpenRef(Expr* expr) openRef->type.isLeftValue = (as<RefType>(exprType) != nullptr); openRef->type.type = refType->getValueType(); openRef->checked = true; + openRef->loc = expr->loc; return openRef; } return expr; @@ -4111,6 +4112,167 @@ Expr* SemanticsExprVisitor::visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr) return sizeOfLikeExpr; } +// Determines if we have a valid `AddressOf` target. +// Target to validate is `baseExpr`. +// Original type is `targetType`. +static PtrType* getValidTypeForAddressOf( + SemanticsVisitor* visitor, + ASTBuilder* m_astBuilder, + Expr* baseExpr, + Type* targetType) +{ + + // If our base is a variable like expression, we should check if this expr is a + // block of memory we allow getting the address of. + if (auto declRefExpr = as<DeclRefExpr>(baseExpr)) + { + visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::DefinitionChecked); + if (auto varDeclRef = as<VarDeclBase>(declRefExpr->declRef)) + { + auto variableType = varDeclRef.substitute(m_astBuilder, targetType); + auto varDecl = varDeclRef.getDecl(); + bool hasVulkanHitObjectAttributesAttribute = false; + bool hasHLSLGroupSharedModifier = false; + for (auto modifier : varDecl->modifiers) + { + if (as<VulkanHitObjectAttributesAttribute>(modifier)) + hasVulkanHitObjectAttributesAttribute = true; + else if (as<HLSLGroupSharedModifier>(modifier)) + hasHLSLGroupSharedModifier = true; + + if (hasVulkanHitObjectAttributesAttribute || hasHLSLGroupSharedModifier) + break; + } + + // Handle variables tagged as [__vulkanHitObjectAttributes]. + // This support is needed for an internal "hack" Slang uses + // for raytracing with `__allocHitObjectAttributes`. + if (hasVulkanHitObjectAttributesAttribute) + { + return m_astBuilder->getPtrType( + variableType, + AccessQualifier::ReadWrite, + AddressSpace::Generic); + } + // Handle 'groupshared' variables. + else if (hasHLSLGroupSharedModifier) + { + return m_astBuilder->getPtrType( + variableType, + AccessQualifier::ReadWrite, + AddressSpace::GroupShared); + } + } + } + + // If our base is a variable like expression, which comes from a deref-like operation, + // we should check if we are able to return a pointer from that base. + auto getPtrTypeFromBaseOfDerefLikeOperation = [&](Expr* baseExpr) -> PtrType* + { + auto declRefExpr = as<DeclRefExpr>(baseExpr); + if (!declRefExpr) + return nullptr; + visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::DefinitionChecked); + auto varDeclRef = as<VarDeclBase>(declRefExpr->declRef); + if (!varDeclRef) + return nullptr; + + auto variableType = varDeclRef.substitute(m_astBuilder, targetType); + + auto ptrType = as<PtrType>(getType(m_astBuilder, varDeclRef)); + if (!ptrType) + return nullptr; + + return m_astBuilder->getPtrType( + variableType, + ptrType->getAccessQualifier(), + ptrType->getAddressSpace()); + }; + + // This logic handles the recursive lookup of "does our operation lead up + // to an addressessable (can take the address-of) section of memory". + if (auto indexExpr = as<IndexExpr>(baseExpr)) + { + // If a user chooses to index into an array, we should check if the base + // expression is something we can get the address-of. + return getValidTypeForAddressOf( + visitor, + m_astBuilder, + indexExpr->baseExpression, + targetType); + } + else if (auto memberExpr = as<MemberExpr>(baseExpr)) + { + // If a user chooses to get a member of a base, we should check if the base + // is something we can get the address-of. + if (as<VarDeclBase>(memberExpr->declRef)) + return getValidTypeForAddressOf( + visitor, + m_astBuilder, + memberExpr->baseExpression, + targetType); + } + else if (auto derefExpr = as<DerefExpr>(baseExpr)) + { + // If a user deref's a variable-like-expression, we should + // check if this is a base expression we can get the address-of. + return getPtrTypeFromBaseOfDerefLikeOperation(derefExpr->base); + } + else if (auto invokeExpr = as<InvokeExpr>(baseExpr)) + { + // We only want to allow function calls if we are getting the address + // of a `GetOffsetPtr` to a pointer-variable + auto functionMemberExpr = as<MemberExpr>(invokeExpr->functionExpr); + if (!functionMemberExpr) + return nullptr; + auto subscriptDecl = as<SubscriptDecl>(functionMemberExpr->declRef.getDecl()); + if (!subscriptDecl) + return nullptr; + bool isOffsetIntrinsicOp = false; + for (auto refAccessor : subscriptDecl->getMembersOfType<RefAccessorDecl>()) + { + auto intrinsicOp = refAccessor->findModifier<IntrinsicOpModifier>(); + if (!intrinsicOp) + continue; + if (intrinsicOp->op != kIROp_GetOffsetPtr) + continue; + isOffsetIntrinsicOp = true; + } + if (!isOffsetIntrinsicOp) + return nullptr; + + return getPtrTypeFromBaseOfDerefLikeOperation(functionMemberExpr->baseExpression); + } + else if (auto swizzleExpr = as<SwizzleExpr>(baseExpr)) + { + // Only allow swizzle of 1 element since otherwise + // we may have a non-contiguous swizzle + // (`val.xxy` is non contiguous). + if (swizzleExpr->elementIndices.getCount() > 1) + return nullptr; + + // Check if the base expression is something we can get the address-of. + return getValidTypeForAddressOf(visitor, m_astBuilder, swizzleExpr->base, targetType); + } + return nullptr; +} + +Expr* SemanticsExprVisitor::visitAddressOfExpr(AddressOfExpr* expr) +{ + expr->arg = CheckTerm(expr->arg); + + // This address-of feature is purely experimental and for prototyping. + // Only allow known expressions. + expr->type = + getValidTypeForAddressOf(this, m_astBuilder, expr->arg, getType(m_astBuilder, expr->arg)); + if (!expr->type) + { + getSink()->diagnose(expr, Diagnostics::invalidAddressOf); + expr->type = m_astBuilder->getErrorType(); + } + return expr; +} + Expr* SemanticsExprVisitor::visitBuiltinCastExpr(BuiltinCastExpr* expr) { // All builtin cast exprs should already be checked. @@ -5744,7 +5906,10 @@ Expr* SemanticsExprVisitor::visitPointerTypeExpr(PointerTypeExpr* expr) expr->base = CheckProperType(expr->base); if (as<ErrorType>(expr->base.type)) expr->type = expr->base.type; - auto ptrType = m_astBuilder->getPtrType(expr->base.type, AddressSpace::UserPointer); + auto ptrType = m_astBuilder->getPtrType( + expr->base.type, + AccessQualifier::ReadWrite, + AddressSpace::UserPointer); expr->type = m_astBuilder->getTypeType(ptrType); return expr; } @@ -5830,6 +5995,11 @@ Val* SemanticsExprVisitor::checkTypeModifier(Modifier* modifier, Type* type) { return m_astBuilder->getNoDiffModifierVal(); } + else if (as<ConstModifier>(modifier)) + { + getSink()->diagnose(modifier, Diagnostics::constNotAllowedOnType); + return nullptr; + } else { // TODO: more complete error message here diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 523041697..e6c66ddd3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -3037,7 +3037,7 @@ public: } Expr* visitSizeOfLikeExpr(SizeOfLikeExpr* expr); - + Expr* visitAddressOfExpr(AddressOfExpr* expr); Expr* visitIncompleteExpr(IncompleteExpr* expr); Expr* visitBoolLiteralExpr(BoolLiteralExpr* expr); Expr* visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index db477ac25..ebda2d637 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1677,6 +1677,18 @@ Modifier* SemanticsVisitor::checkModifier( } } + if (as<ConstModifier>(m)) + { + if (auto varDeclBase = as<VarDeclBase>(syntaxNode)) + { + if (as<PointerTypeExpr>(varDeclBase->type.exp)) + { + // Disallow `const T*` syntax. + getSink()->diagnose(m, Diagnostics::constNotAllowedOnCStylePtrDecl); + return nullptr; + } + } + } if (auto glslLayoutAttribute = as<UncheckedGLSLLayoutAttribute>(m)) { return checkGLSLLayoutAttribute(glslLayoutAttribute, syntaxNode); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index f949e2632..e0dbc7e08 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1196,8 +1196,25 @@ Expr* SemanticsVisitor::CompleteOverloadCandidate( { // If the subscript decl has a setter, // then the call is an l-value if base is l-value. + // + // If Ptr<T, Access> we only need to check for ReadWrite + // Access (if ReadWrite result is an LValue. By default a + // Ptr<...> is Read-only (unresolved generic argument & Access::Read). if (auto base = GetBaseExpr(baseExpr)) { + if (auto ptrTypeBase = as<PtrTypeBase>(base->type)) + { + auto accessQualifier = + as<ConstantIntVal>(ptrTypeBase->getAccessQualifier()); + if (!accessQualifier || + AccessQualifier(accessQualifier->getValue()) == + AccessQualifier::ReadWrite) + { + callExpr->type.isLeftValue = true; + } + break; + } + if (base->type.isLeftValue) { callExpr->type.isLeftValue = true; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 45deea109..7c9111629 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -784,7 +784,7 @@ Type* getParamTypeWithDirectionWrapper(ASTBuilder* astBuilder, DeclRef<VarDeclBa case kParameterDirection_InOut: return astBuilder->getInOutType(result); case kParameterDirection_Ref: - return astBuilder->getRefType(result, AddressSpace::Generic); + return astBuilder->getRefType(result); default: return result; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index ecc3fb3f3..a30b5f362 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -524,6 +524,14 @@ DIAGNOSTIC( Error, missingLayoutBindingModifier, "Expecting 'binding' modifier in the layout qualifier here") +DIAGNOSTIC( + 20017, + Error, + constNotAllowedOnCStylePtrDecl, + "'const' not allowed on pointer typed declarations using the C style '*' operator. " + "If the intent is to restrict the pointed-to value to read-only, use 'Ptr<T, Access.Read>'; " + "if the intent is to make the pointer itself immutable, use 'let' or 'const Ptr<...>'.") +DIAGNOSTIC(20018, Error, constNotAllowedOnType, "cannot use 'const' as a type modifier") DIAGNOSTIC( 20101, @@ -702,11 +710,6 @@ DIAGNOSTIC( argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.") DIAGNOSTIC( - 30078, - Error, - cannotTakeConstantPointers, - "Not allowed to take pointer of an immutable object") -DIAGNOSTIC( 30048, Error, argumentHasMoreMemoryQualifiersThanParam, @@ -823,7 +826,17 @@ DIAGNOSTIC( "function, you can replace '$2 $0' with a generic 'T $0' and a 'where T : $2' constraint.") DIAGNOSTIC(-1, Note, doYouMeanStaticConst, "do you intend to define a `static const` instead?") DIAGNOSTIC(-1, Note, doYouMeanUniform, "do you intend to define a `uniform` parameter instead?") - +DIAGNOSTIC( + 30078, + Error, + coherentKeywordOnAPointer, + "cannot have a `globallycoherent T*` or a `coherent T*`, use explicit methods for coherent " + "operations instead") +DIAGNOSTIC( + 30079, + Error, + cannotTakeConstantPointers, + "Not allowed to take the address of an immutable object") DIAGNOSTIC( 30100, Error, @@ -927,11 +940,7 @@ DIAGNOSTIC( Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible") -DIAGNOSTIC( - 30080, - Error, - ambiguousConversion, - "more than one implicit conversion exists from '$0' to '$1'") +DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one conversion exists from '$0' to '$1'") DIAGNOSTIC( 30081, Warning, @@ -1432,7 +1441,11 @@ DIAGNOSTIC( "If this is intended, consider using [NoDiffThis] on the function '$1' to suppress this " "warning. Alternatively, users can mark the parent struct as [Differentiable] to propagate " "derivatives.") - +DIAGNOSTIC( + 31160, + Error, + invalidAddressOf, + "'__getAddress' only supports groupshared variables and members of groupshared/device memory.") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.") DIAGNOSTIC( @@ -2682,6 +2695,8 @@ DIAGNOSTIC( "cannot perform atomic operation because destination is neither groupshared nor from a device " "buffer.") +DIAGNOSTIC(41404, Error, cannotWriteToReadOnlyPointer, "cannot write to a read-only pointer") + // // 5xxxx - Target code generation. // diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2127141bd..5036333c1 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1844,7 +1844,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex (IRPtrTypeBase*)inst)) emitOpTypeForwardPointer(resultSpvType, storageClass); } - if (storageClass == SpvStorageClassPhysicalStorageBuffer) + if (storageClass == SpvStorageClassPhysicalStorageBuffer || + storageClass == SpvStorageClassStorageBuffer) { if (m_decoratedSpvInsts.add(getID(resultSpvType))) { @@ -3271,6 +3272,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex builder.getPtrType( kIROp_PtrType, spvAsmBuiltinVar->getDataType(), + AccessQualifier::ReadWrite, AddressSpace::BuiltinInput), kind, spvAsmBuiltinVar); @@ -3868,20 +3870,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return spvDebugLocalVar; } - bool isLegalType(IRInst* type) + // Returns true if the given type is allowed to emit for a `DebugVar`. + // Other types may not be illegal, but Slang currently does not support + // emitting these other DebugVar types. + bool isAllowedDebugVarType(IRInst* type) { switch (type->getOp()) { case kIROp_UnsizedArrayType: return false; case kIROp_ArrayType: - return isLegalType(as<IRArrayType>(type)->getElementType()); + return isAllowedDebugVarType(as<IRArrayType>(type)->getElementType()); case kIROp_VectorType: case kIROp_StructType: case kIROp_MatrixType: return true; case kIROp_PtrType: - return as<IRPtrTypeBase>(type)->getAddressSpace() == AddressSpace::UserPointer; + return as<IRPtrTypeBase>(type)->getAddressSpace() == AddressSpace::UserPointer || + as<IRPtrTypeBase>(type)->getAddressSpace() == AddressSpace::GroupShared; default: if (as<IRBasicType>(type)) return true; @@ -3899,7 +3905,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex builder.setInsertBefore(debugVar); auto varType = tryGetPointedToType(&builder, debugVar->getDataType()); - if (!isLegalType(varType)) + if (!isAllowedDebugVarType(varType)) return nullptr; IRSizeAndAlignment sizeAlignment; @@ -9473,10 +9479,16 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { if (auto ptrType = as<IRPtrTypeBase>(type)) { - if (ptrType->getAddressSpace() == AddressSpace::StorageBuffer) + switch (ptrType->getAddressSpace()) { + case AddressSpace::StorageBuffer: + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_variable_pointers")); + requireSPIRVCapability(SpvCapabilityVariablePointersStorageBuffer); + break; + case AddressSpace::GroupShared: ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_variable_pointers")); requireSPIRVCapability(SpvCapabilityVariablePointers); + break; } } } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index a46b20b5f..011d1fec9 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -202,7 +202,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( { auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType()); - pairType = builder->getPtrType(kIROp_PtrType, (IRType*)loweredType); + pairType = builder->getPtrType((IRType*)loweredType); } else { diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp index 9fae4ce15..dd07db883 100644 --- a/source/slang/slang-ir-explicit-global-context.cpp +++ b/source/slang/slang-ir-explicit-global-context.cpp @@ -316,8 +316,7 @@ struct IntroduceExplicitGlobalContextPass // The context will usually be passed around by pointer, // so we get and cache that pointer type up front. // - m_contextStructPtrType = - builder.getPtrType(kIROp_PtrType, m_contextStructType, getAddressSpaceOfLocal()); + m_contextStructPtrType = builder.getPtrType(m_contextStructType, getAddressSpaceOfLocal()); // The first step will be to create fields in the `KernelContext` @@ -630,7 +629,7 @@ struct IntroduceExplicitGlobalContextPass auto ptrType = getGlobalVarPtrType(globalVar); if (fieldInfo.needDereference) - ptrType = builder.getPtrType(kIROp_PtrType, ptrType, getAddressSpaceOfLocal()); + ptrType = builder.getPtrType(ptrType, getAddressSpaceOfLocal()); // We then iterate over the uses of the variable, // being careful to defend against the use/def information diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 65b997195..203a610dd 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -1466,7 +1466,11 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // Set the array size to 0, to mean it is unsized auto arrayType = builder->getArrayType(type, 0); - IRType* paramType = builder->getPtrType(ptrOpCode, arrayType, addrSpace); + auto accessQualifier = AccessQualifier::ReadWrite; + if (kind == LayoutResourceKind::VaryingInput) + accessQualifier = AccessQualifier::Read; + IRType* paramType = + builder->getPtrType(ptrOpCode, arrayType, accessQualifier, addrSpace); auto globalParam = addGlobalParam(builder->getModule(), paramType); moveValueBefore(globalParam, builder->getFunc()); @@ -2558,7 +2562,7 @@ static void consolidateParameters(GLSLLegalizationContext* context, List<IRParam // Create a global variable to hold the consolidated struct consolidatedVar = builder->createGlobalVar(structType); - auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload); + auto ptrType = builder->getPtrType(structType, AddressSpace::IncomingRayPayload); consolidatedVar->setFullType(ptrType); consolidatedVar->moveToEnd(); @@ -3088,7 +3092,8 @@ IRInst* getOrCreatePerVertexInputArray(GLSLLegalizationContext* context, IRInst* auto arrayType = builder.getArrayType( tryGetPointedToType(&builder, inputVertexAttr->getDataType()), builder.getIntValue(builder.getIntType(), 3)); - arrayInst = builder.createGlobalParam(builder.getPtrType(arrayType, AddressSpace::Input)); + arrayInst = builder.createGlobalParam( + builder.getPtrType(arrayType, AccessQualifier::Read, AddressSpace::Input)); context->mapVertexInputToPerVertexArray[inputVertexAttr] = arrayInst; builder.addDecoration(arrayInst, kIROp_PerVertexDecoration); @@ -4301,10 +4306,7 @@ void legalizeEntryPointForGLSL( // Re-add ptr if there was one on the input if (ptrType) { - sizedArrayType = builder.getPtrType( - ptrType->getOp(), - sizedArrayType, - ptrType->getAddressSpace()); + sizedArrayType = builder.getPtrType(sizedArrayType, ptrType); } // Change the globals type diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index c8604f4fa..9208e546c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3721,6 +3721,7 @@ public: IRGenericKind* getGenericKind(); IRPtrType* getPtrType(IRType* valueType); + IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); // Form a ptr type to `valueType` using the same opcode and address space as `ptrWithAddrSpace`. IRPtrTypeBase* getPtrTypeWithAddressSpace(IRType* valueType, IRPtrTypeBase* ptrWithAddrSpace); @@ -3728,14 +3729,41 @@ public: IROutType* getOutType(IRType* valueType); IRInOutType* getInOutType(IRType* valueType); IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace); - IRConstRefType* getConstRefType(IRType* valueType); IRConstRefType* getConstRefType(IRType* valueType, AddressSpace addrSpace); - IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); - IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace); - IRPtrType* getPtrType(IROp op, IRType* valueType, IRInst* addressSpace); + IRPtrType* getPtrType( + IROp op, + IRType* valueType, + AccessQualifier accessQualifier, + AddressSpace addressSpace); + IRPtrType* getPtrType( + IROp op, + IRType* valueType, + IRInst* accessQualifier, + IRInst* addressSpace); + IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) + { + return getPtrType(op, valueType, AccessQualifier::ReadWrite, addressSpace); + } + IRPtrType* getPtrType( + IRType* valueType, + AccessQualifier accessQualifier, + AddressSpace addressSpace) + { + return getPtrType(kIROp_PtrType, valueType, accessQualifier, addressSpace); + } IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace) { - return getPtrType(kIROp_PtrType, valueType, addressSpace); + return getPtrType(valueType, AccessQualifier::ReadWrite, addressSpace); + } + // Copies the op-type of the oldPtrType, access-qualifier and address-space. + // Does not reuse the same `inst` for access-qualifier and address-space. + IRPtrTypeBase* getPtrType(IRType* valueType, IRPtrTypeBase* oldPtrType) + { + return getPtrType( + oldPtrType->getOp(), + valueType, + oldPtrType->getAccessQualifier(), + oldPtrType->getAddressSpace()); } IRTextureTypeBase* getTextureType( diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 9cf8d7b5f..9363bc882 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -3699,6 +3699,7 @@ static LegalVal legalizeGlobalVar(IRTypeLegalizationContext* context, IRGlobalVa irGlobalVar, context->builder->getPtrType( legalValueType.getSimple(), + varPtrType ? varPtrType->getAccessQualifier() : AccessQualifier::ReadWrite, varPtrType ? varPtrType->getAddressSpace() : AddressSpace::Global)); return LegalVal::simple(irGlobalVar); diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 621c7a55e..062330836 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3597,10 +3597,8 @@ protected: IRPtrTypeBase* type = as<IRPtrTypeBase>(param->getDataType()); - const auto annotatedPayloadType = builder.getPtrType( - kIROp_ConstRefType, - type->getValueType(), - AddressSpace::MetalObjectData); + const auto annotatedPayloadType = + builder.getConstRefType(type->getValueType(), AddressSpace::MetalObjectData); param->setFullType(annotatedPayloadType); } diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp index 2bc1de775..29f1ec516 100644 --- a/source/slang/slang-ir-specialize-address-space.cpp +++ b/source/slang/slang-ir-specialize-address-space.cpp @@ -103,8 +103,11 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext if (ptrType) { auto paramAddrSpace = key.getArgAddrSpaces()[paramIndex]; - auto newParamType = - builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), paramAddrSpace); + auto newParamType = builder.getPtrType( + ptrType->getOp(), + ptrType->getValueType(), + ptrType->getAccessQualifier(), + paramAddrSpace); param->setFullType(newParamType); mapInstToAddrSpace[param] = paramAddrSpace; } @@ -310,6 +313,7 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext auto newResultType = builder.getPtrType( ptrResultType->getOp(), ptrResultType->getValueType(), + ptrResultType->getAccessQualifier(), addrSpace); fixUpFuncType(func, newResultType); retValAddrSpaceChanged = true; @@ -349,8 +353,11 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext if (ptrType->getAddressSpace() != addrSpace) { IRBuilder builder(inst); - auto newType = - builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), addrSpace); + auto newType = builder.getPtrType( + ptrType->getOp(), + ptrType->getValueType(), + ptrType->getAccessQualifier(), + addrSpace); setDataType(inst, newType); } } diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index c03e644de..7c82891a6 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -662,12 +662,12 @@ struct FunctionParameterSpecializationContext case kIROp_OutType: case kIROp_RefType: case kIROp_ConstRefType: - argType = as<IRPtrTypeBase>(argType)->getValueType(); - resultType = getBuilder()->getPtrType( - paramType->getOp(), - argType, - as<IRPtrTypeBase>(paramType)->getAddressSpace()); - break; + { + auto ptrParamType = as<IRPtrTypeBase>(paramType); + argType = as<IRPtrTypeBase>(argType)->getValueType(); + resultType = getBuilder()->getPtrType(argType, ptrParamType); + break; + } } if (auto rate = paramType->getRate()) { diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 9433b560b..2c4bd11cc 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -249,7 +249,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase { builder.setInsertBefore(use->getUser()); auto addr = builder.emitFieldAddress( - builder.getPtrType(kIROp_PtrType, innerType, AddressSpace::Uniform), + builder.getPtrType(innerType, AccessQualifier::Read, AddressSpace::Uniform), cbParamInst, key); use->set(addr); @@ -291,12 +291,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto basePtrType = as<IRPtrTypeBase>(addr->getDataType()); IRType* ptrType = nullptr; if (basePtrType->hasAddressSpace()) - ptrType = builder.getPtrType( - kIROp_PtrType, - user->getDataType(), - basePtrType->getAddressSpace()); + ptrType = builder.getPtrType(user->getDataType(), basePtrType); else - ptrType = builder.getPtrType(kIROp_PtrType, user->getDataType()); + ptrType = builder.getPtrType(user->getDataType()); IRInst* subAddr = nullptr; if (user->getOp() == kIROp_GetElement) subAddr = builder.emitElementAddress( @@ -443,6 +440,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase return; } + AccessQualifier access = AccessQualifier::ReadWrite; // Opaque resource handles can't be in Uniform for Vulkan, if they are // placed here then put them in UniformConstant instead if (isSpirvUniformConstantType(inst->getDataType())) @@ -518,7 +516,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // structured buffers in GLSL should be annotated as ReadOnly if (as<IRHLSLStructuredBufferType>(structuredBufferType)) + { + access = AccessQualifier::Read; memoryFlags = MemoryQualifierSetModifier::Flags::kReadOnly; + } if (as<IRHLSLRasterizerOrderedStructuredBufferType>(structuredBufferType)) memoryFlags = MemoryQualifierSetModifier::Flags::kRasterizerOrdered; @@ -555,7 +556,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // Make a pointer type of storageClass. builder.setInsertBefore(inst); - ptrType = builder.getPtrType(kIROp_PtrType, innerType, addressSpace); + ptrType = builder.getPtrType(innerType, access, addressSpace); inst->setFullType(ptrType); if (needLoad) { @@ -578,7 +579,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(getElement); builder.setInsertBefore(user); auto newAddr = builder.emitElementAddress( - builder.getPtrType(kIROp_PtrType, innerElementType, addressSpace), + builder.getPtrType( + innerElementType, + ptrType->getAccessQualifier(), + addressSpace), inst, getElement->getIndex()); user->replaceUsesWith(newAddr); @@ -714,6 +718,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newPtrType = builder.getPtrType( oldPtrType->getOp(), oldPtrType->getValueType(), + oldPtrType->getAccessQualifier(), AddressSpace::Function); inst->setFullType(newPtrType); addUsersToWorkList(inst); @@ -735,9 +740,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase if (block == func->getFirstBlock()) { - // A pointer typed function parameter should always be in the storage buffer address - // space. - addressSpace = AddressSpace::UserPointer; + // A pointer typed function parameter is in the storage buffer address + // space or groupshared. + if (as<IRGroupSharedRate>(inst->getRate())) + addressSpace = AddressSpace::GroupShared; + else + addressSpace = AddressSpace::UserPointer; } else { @@ -765,6 +773,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newPtrType = builder.getPtrType( oldPtrType->getOp(), oldPtrType->getValueType(), + oldPtrType->getAccessQualifier(), AddressSpace::UserPointer); inst->setFullType(newPtrType); addUsersToWorkList(inst); @@ -785,10 +794,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(inst); builder.setInsertBefore(inst); IRType* newPtrType = oldPtrType->hasAddressSpace() - ? builder.getPtrType( - oldPtrType->getOp(), - newPtrValueType, - oldPtrType->getAddressSpace()) + ? builder.getPtrType(newPtrValueType, oldPtrType) : builder.getPtrType(oldPtrType->getOp(), newPtrValueType); inst->setFullType(newPtrType); } @@ -839,8 +845,11 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } IRBuilder builder(m_sharedContext->m_irModule); builder.setInsertBefore(inst); - auto newPtrType = - builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), addressSpace); + auto newPtrType = builder.getPtrType( + oldPtrType->getOp(), + oldPtrType->getValueType(), + oldPtrType->getAccessQualifier(), + addressSpace); inst->setFullType(newPtrType); addUsersToWorkList(inst); return; @@ -1022,6 +1031,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newPtrType = builder.getPtrType( oldResultType->getOp(), oldResultType->getValueType(), + ptrType->getAccessQualifier(), ptrType->getAddressSpace()); IRInst* args[2] = {base, index}; auto newInst = builder.emitIntrinsicInst(newPtrType, gepInst->getOp(), 2, args); @@ -1075,6 +1085,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newResultType = builder.getPtrType( resultPtrType->getOp(), resultPtrType->getValueType(), + ptrOperandType->getAccessQualifier(), ptrOperandType->getAddressSpace()); auto newInst = builder.replaceOperand(&offsetPtrInst->typeUse, newResultType); addUsersToWorkList(newInst); @@ -1095,8 +1106,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(loadInst); IRInst* args[] = {sb, index}; auto addrInst = builder.emitIntrinsicInst( - builder - .getPtrType(kIROp_PtrType, loadInst->getFullType(), getStorageBufferAddressSpace()), + builder.getPtrType(loadInst->getFullType(), getStorageBufferAddressSpace()), kIROp_RWStructuredBufferGetElementPtr, 2, args); @@ -1115,7 +1125,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase builder.setInsertBefore(storeInst); IRInst* args[] = {sb, index}; auto addrInst = builder.emitIntrinsicInst( - builder.getPtrType(kIROp_PtrType, value->getFullType(), getStorageBufferAddressSpace()), + builder.getPtrType(value->getFullType(), getStorageBufferAddressSpace()), kIROp_RWStructuredBufferGetElementPtr, 2, args); @@ -1168,6 +1178,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase auto newPtrType = builder.getPtrType( oldResultType->getOp(), newValueType, + ptrType->getAccessQualifier(), ptrType->getAddressSpace()); auto newInst = builder.emitFieldAddress(newPtrType, inst->getBase(), inst->getField()); @@ -2172,8 +2183,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } // Update the global param's type to use the wrapper struct - auto newPtrType = - builder.getPtrType(ptrType->getOp(), wrapperStruct, ptrType->getAddressSpace()); + auto newPtrType = builder.getPtrType(wrapperStruct, ptrType); globalParam->setFullType(newPtrType); // Traverse all uses of the global param and insert a FieldAddress to access the @@ -2184,7 +2194,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase { builder.setInsertBefore(use->getUser()); auto addr = builder.emitFieldAddress( - builder.getPtrType(kIROp_PtrType, structType, ptrType->getAddressSpace()), + builder.getPtrType(structType, ptrType), globalParam, key); use->set(addr); @@ -2246,11 +2256,16 @@ struct SPIRVLegalizationContext : public SourceEmitterBase for (auto t : instsToProcess) { auto lowered = lowerStructuredBufferType(t); + + AccessQualifier accessQualifier = AccessQualifier::ReadWrite; + if (as<IRHLSLStructuredBufferType>(t)) + accessQualifier = AccessQualifier::Read; + IRBuilder builder(t); builder.setInsertBefore(t); t->replaceUsesWith(builder.getPtrType( - kIROp_PtrType, lowered.structType, + accessQualifier, getStorageBufferAddressSpace())); } for (auto t : textureFootprintTypes) diff --git a/source/slang/slang-ir-translate-global-varying-var.cpp b/source/slang/slang-ir-translate-global-varying-var.cpp index 57b277d25..c899de653 100644 --- a/source/slang/slang-ir-translate-global-varying-var.cpp +++ b/source/slang/slang-ir-translate-global-varying-var.cpp @@ -220,8 +220,8 @@ struct GlobalVarTranslationContext input->transferDecorationsTo(key); // Emit a new param here to represent the global input var. - auto inputParam = builder.emitParam( - builder.getPtrType(kIROp_ConstRefType, inputType, AddressSpace::Input)); + auto inputParam = + builder.emitParam(builder.getConstRefType(inputType, AddressSpace::Input)); // Copy the global input vars original decorations onto the new param. // We need to do this to ensure that we can do things like get system diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6b9273c15..ab59112f3 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2935,17 +2935,13 @@ IRInOutType* IRBuilder::getInOutType(IRType* valueType) IRRefType* IRBuilder::getRefType(IRType* valueType, AddressSpace addrSpace) { - return (IRRefType*)getPtrType(kIROp_RefType, valueType, addrSpace); -} - -IRConstRefType* IRBuilder::getConstRefType(IRType* valueType) -{ - return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType); + return (IRRefType*)getPtrType(kIROp_RefType, valueType, AccessQualifier::ReadWrite, addrSpace); } IRConstRefType* IRBuilder::getConstRefType(IRType* valueType, AddressSpace addrSpace) { - return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType, addrSpace); + return ( + IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType, AccessQualifier::Read, addrSpace); } IRSPIRVLiteralType* IRBuilder::getSPIRVLiteralType(IRType* type) @@ -2965,23 +2961,35 @@ IRPtrTypeBase* IRBuilder::getPtrTypeWithAddressSpace( IRPtrTypeBase* ptrWithAddrSpace) { if (ptrWithAddrSpace->hasAddressSpace()) - return (IRPtrTypeBase*) - getPtrType(ptrWithAddrSpace->getOp(), valueType, ptrWithAddrSpace->getAddressSpace()); + return (IRPtrTypeBase*)getPtrType( + ptrWithAddrSpace->getOp(), + valueType, + ptrWithAddrSpace->getAccessQualifier(), + ptrWithAddrSpace->getAddressSpace()); return (IRPtrTypeBase*)getPtrType(ptrWithAddrSpace->getOp(), valueType); } -IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace) +IRPtrType* IRBuilder::getPtrType( + IROp op, + IRType* valueType, + AccessQualifier accessQualifier, + AddressSpace addressSpace) { return (IRPtrType*)getPtrType( op, valueType, + getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(accessQualifier)), getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(addressSpace))); } -IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRInst* addressSpace) +IRPtrType* IRBuilder::getPtrType( + IROp op, + IRType* valueType, + IRInst* accessQualifier, + IRInst* addressSpace) { - IRInst* operands[] = {valueType, addressSpace}; - return (IRPtrType*)getType(op, addressSpace ? 2 : 1, operands); + IRInst* operands[] = {valueType, accessQualifier, addressSpace}; + return (IRPtrType*)getType(op, addressSpace ? 3 : 1, operands); } IRTextureTypeBase* IRBuilder::getTextureType( @@ -4822,7 +4830,7 @@ IRGlobalVar* IRBuilder::createGlobalVar(IRType* valueType) IRGlobalVar* IRBuilder::createGlobalVar(IRType* valueType, AddressSpace addressSpace) { - auto ptrType = getPtrType(kIROp_PtrType, valueType, addressSpace); + auto ptrType = getPtrType(valueType, addressSpace); IRGlobalVar* globalVar = createInst<IRGlobalVar>(this, kIROp_GlobalVar, ptrType); _maybeSetSourceLoc(globalVar); addGlobalValue(this, globalVar); @@ -5079,7 +5087,7 @@ IRVar* IRBuilder::emitVar(IRType* type) IRVar* IRBuilder::emitVar(IRType* type, AddressSpace addressSpace) { - auto allocatedType = getPtrType(kIROp_PtrType, type, addressSpace); + auto allocatedType = getPtrType(type, addressSpace); auto inst = createInst<IRVar>(this, kIROp_Var, allocatedType); addInst(inst); return inst; @@ -5308,6 +5316,7 @@ IRType* maybePropagateAddressSpace(IRBuilder* builder, IRInst* basePtr, IRType* type = builder->getPtrType( resultPtrType->getOp(), resultPtrType->getValueType(), + basePtrType->getAccessQualifier(), basePtrType->getAddressSpace()); } } @@ -5318,10 +5327,12 @@ IRType* maybePropagateAddressSpace(IRBuilder* builder, IRInst* basePtr, IRType* IRInst* IRBuilder::emitFieldAddress(IRInst* basePtr, IRInst* fieldKey) { AddressSpace addrSpace = AddressSpace::Generic; + AccessQualifier accessQualifier = AccessQualifier::ReadWrite; IRInst* valueType = nullptr; auto basePtrType = unwrapAttributedType(basePtr->getDataType()); if (auto ptrType = as<IRPtrTypeBase>(basePtrType)) { + accessQualifier = ptrType->getAccessQualifier(); addrSpace = ptrType->getAddressSpace(); valueType = ptrType->getValueType(); } @@ -5344,7 +5355,7 @@ IRInst* IRBuilder::emitFieldAddress(IRInst* basePtr, IRInst* fieldKey) } } SLANG_RELEASE_ASSERT(resultType); - return emitFieldAddress(getPtrType(kIROp_PtrType, resultType, addrSpace), basePtr, fieldKey); + return emitFieldAddress(getPtrType(resultType, accessQualifier, addrSpace), basePtr, fieldKey); } IRInst* IRBuilder::emitFieldAddress(IRType* type, IRInst* base, IRInst* field) @@ -5448,10 +5459,12 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRIntegerValue index) IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index) { AddressSpace addrSpace = AddressSpace::Generic; + AccessQualifier accessQualifier = AccessQualifier::ReadWrite; IRInst* valueType = nullptr; auto basePtrType = unwrapAttributedType(basePtr->getDataType()); if (auto ptrType = as<IRPtrTypeBase>(basePtrType)) { + accessQualifier = ptrType->getAccessQualifier(); addrSpace = ptrType->getAddressSpace(); valueType = ptrType->getValueType(); } @@ -5500,7 +5513,7 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index) auto inst = createInst<IRGetElementPtr>( this, kIROp_GetElementPtr, - getPtrType(kIROp_PtrType, type, addrSpace), + getPtrType(type, accessQualifier, addrSpace), basePtr, index); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index ad14edf21..d8fe51ddf 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1681,15 +1681,22 @@ struct IRPtrTypeBase : IRType FIDDLE(baseInst()) IRType* getValueType() { return (IRType*)getOperand(0); } + AccessQualifier getAccessQualifier() + { + return getOperandCount() > 1 + ? (AccessQualifier) static_cast<IRIntLit*>(getOperand(1))->getValue() + : AccessQualifier::ReadWrite; + } + bool hasAddressSpace() { - return getOperandCount() > 1 && getAddressSpace() != AddressSpace::Generic; + return getOperandCount() > 2 && getAddressSpace() != AddressSpace::Generic; } AddressSpace getAddressSpace() { - return getOperandCount() > 1 - ? (AddressSpace) static_cast<IRIntLit*>(getOperand(1))->getValue() + return getOperandCount() > 2 + ? (AddressSpace) static_cast<IRIntLit*>(getOperand(2))->getValue() : AddressSpace::Generic; } }; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 8e1f85f8e..4df778ee6 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2067,7 +2067,14 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto astValueType = type->getValueType(); IRType* irValueType = lowerType(context, astValueType); + IRInst* accessQualifier = nullptr; IRInst* addrSpace = nullptr; + + if (auto astAccessQualifier = type->getAccessQualifier()) + { + accessQualifier = getSimpleVal(context, lowerVal(context, astAccessQualifier)); + } + if (auto astAddrSpace = type->getAddressSpace()) { addrSpace = getSimpleVal(context, lowerVal(context, astAddrSpace)); @@ -2078,7 +2085,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower getBuilder()->getUInt64Type(), (IRIntegerValue)AddressSpace::Generic); } - return getBuilder()->getPtrType(kIROp_PtrType, irValueType, addrSpace); + + return getBuilder()->getPtrType(kIROp_PtrType, irValueType, accessQualifier, addrSpace); } IRType* visitDeclRefType(DeclRefType* type) @@ -3437,7 +3445,6 @@ void _lowerFuncDeclBaseTypeInfo( auto& parameterLists = outInfo.parameterLists; collectParameterLists( context, - declRef, ¶meterLists, kParameterListCollectMode_Default, @@ -3469,7 +3476,7 @@ void _lowerFuncDeclBaseTypeInfo( irParamType = builder->getRefType(irParamType, AddressSpace::Generic); break; case kParameterDirection_ConstRef: - irParamType = builder->getConstRefType(irParamType); + irParamType = builder->getConstRefType(irParamType, AddressSpace::Generic); break; default: SLANG_UNEXPECTED("unknown parameter direction"); @@ -4157,6 +4164,39 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> ASTBuilder* getASTBuilder() { return context->astBuilder; } LoweredValInfo lowerSubExpr(Expr* expr) { return sharedLoweringContext.lowerSubExpr(expr); } + LoweredValInfo visitAddressOfExpr(AddressOfExpr* expr) + { + auto loweredType = lowerType(context, expr->type); + auto baseVal = lowerLValueExpr(context, expr->arg); + auto ptr = tryGetAddress(context, baseVal, TryGetAddressMode::Aggressive); + + switch (ptr.flavor) + { + case LoweredValInfo::Flavor::Ptr: + { + // TODO: This is a hack. We should just be returning `ptr`. We do not do this since + // `ptr` may have the wrong address space. This happens since when lowering-to-ir we + // don't check what addres-space info we should be using for variables we create. + // example: `groupshared int ptr` ==> lower-to-ir lowers as default address-space + // with groupshared-rate. + // + // We need to emit a temporary variable (and cannot emit a cast) since `operator*` + // has its own hacks and is an incorrect implementation of its own. To elaborate, + // `operator*` is defined as `__intrinsic_op(0)`, which means "pass arguments + // through a function `in`, then set as result". This is an issue since this means + // that our function (which should be returning a `ref`) may in fact, not be + // returning a `ref` but instead be loading via the `in` parameter and generating a + // non-pointer result. + auto irVar = context->irBuilder->emitVar(loweredType); + context->irBuilder->emitStore(irVar, ptr.val); + return LoweredValInfo::ptr(irVar); + } + default: + SLANG_UNIMPLEMENTED_X("cannot get address of __getAddress(...) argument"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + } + LoweredValInfo visitIncompleteExpr(IncompleteExpr*) { SLANG_UNEXPECTED("a valid ast should not contain an IncompleteExpr."); diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index ea620ebb2..7a5665905 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -191,6 +191,16 @@ void emitBaseType(ManglingContext* context, BaseType baseType) case BaseType::IntPtr: emitRaw(context, "ip"); break; + case BaseType::AddressSpace: + emitRaw(context, "as"); + break; + case BaseType::AccessQualifier: + emitRaw(context, "aq"); + break; + case BaseType::MemoryScope: + emitRaw(context, "mem"); + break; + default: SLANG_UNEXPECTED("unimplemented case in base type mangling"); break; diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 1302975df..e71b6162c 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -7026,7 +7026,6 @@ static NodeBase* parseSizeOfExpr(Parser* parser, void* /*userData*/) static NodeBase* parseAlignOfExpr(Parser* parser, void* /*userData*/) { - // We could have a type or a variable or an expression AlignOfExpr* alignOfExpr = parser->astBuilder->create<AlignOfExpr>(); parser->ReadMatchingToken(TokenType::LParent); @@ -7058,6 +7057,17 @@ static NodeBase* parseCountOfExpr(Parser* parser, void* /*userData*/) return countOfExpr; } +static NodeBase* parseAddressOfExpr(Parser* parser, void* /*userData*/) +{ + // We could have a type or a variable or an expression + AddressOfExpr* addressOfExpr = parser->astBuilder->create<AddressOfExpr>(); + + parser->ReadMatchingToken(TokenType::LParent); + addressOfExpr->arg = parser->ParseExpression(); + parser->ReadMatchingToken(TokenType::RParent); + return addressOfExpr; +} + static NodeBase* parseTryExpr(Parser* parser, void* /*userData*/) { auto tryExpr = parser->astBuilder->create<TryExpr>(); @@ -9648,6 +9658,7 @@ static const SyntaxParseInfo g_parseSyntaxEntries[] = { _makeParseExpr("sizeof", parseSizeOfExpr), _makeParseExpr("alignof", parseAlignOfExpr), _makeParseExpr("countof", parseCountOfExpr), + _makeParseExpr("__getAddress", parseAddressOfExpr), }; ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos() diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 67d562f0f..d9eb884f0 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -927,7 +927,7 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR } if (paramDecl->findModifier<RefModifier>()) { - paramType = astBuilder->getRefType(paramType, AddressSpace::Generic); + paramType = astBuilder->getRefType(paramType); } else if (paramDecl->findModifier<ConstRefModifier>()) { diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h index d7bd43122..3390c3b80 100644 --- a/source/slang/slang-type-system-shared.h +++ b/source/slang/slang-type-system-shared.h @@ -22,6 +22,10 @@ namespace Slang X(Char) \ X(IntPtr) \ X(UIntPtr) \ + X(CountOfPrimitives) \ + X(AddressSpace) \ + X(MemoryScope) \ + X(AccessQualifier) \ /* end */ enum class BaseType @@ -114,6 +118,26 @@ enum class AddressSpace : uint64_t // Default address space for a user-defined pointer UserPointer = 0x100000001ULL, }; + +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id +// must be 32 bit to match SPIR-V +enum class MemoryScope : int32_t +{ + CrossDevice = 0, + Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + QueueFamily = 5, + ShaderCall = 6, +}; + +enum class AccessQualifier : uint64_t +{ + ReadWrite = 0, + Read = 1, +}; + } // namespace Slang #endif diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index edd34f456..97f4689f1 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -260,7 +260,8 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndTypeExpr">(Slang::AndTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedTypeExpr">(Slang::ModifiedTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PointerTypeExpr">(Slang::PointerTypeExpr*)&astNodeType</ExpandedItem> - <Item Name="[type]">type</Item> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AddressOfExpr">(Slang::AddressOfExpr*)&astNodeType</ExpandedItem> + <Item Name="[type]">type</Item> <Item Name="[Expr]">(Slang::Expr*)this,!</Item> </Expand> </Type> @@ -484,11 +485,12 @@ </Expand> </Type> <Type Name="Slang::ValNodeOperand"> - <DisplayString Optional="true" Condition="kind==Slang::ValNodeOperandKind::ConstantValue">Const({values.intOperand})#{_debugUID}</DisplayString> + <DisplayString Condition="kind==Slang::ValNodeOperandKind::ConstantValue">ConstantValue ({this->values.intOperand}) #{((Val*)this)->_debugUID}</DisplayString> <DisplayString Condition="kind==Slang::ValNodeOperandKind::ValNode">{*(Val*)values.nodeOperand}</DisplayString> <DisplayString>{values.nodeOperand}</DisplayString> <Expand> - <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ValNode">*(Val*)values.nodeOperand</ExpandedItem> + <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ConstantValue">values</ExpandedItem> + <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ValNode">*(Val*)values.nodeOperand</ExpandedItem> <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ASTNode">*(Decl*)values.nodeOperand</ExpandedItem> </Expand> </Type> |
