summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2025-08-29 15:52:34 -0700
committerGitHub <noreply@github.com>2025-08-29 22:52:34 +0000
commit7758625d3fea67e55e98e7e4103d56c9918365be (patch)
tree2ed40aeb4d16262866e5540dad1a519951b5f772
parent450ef7934c1adfdf4a3a3c72967de3c5798a020d (diff)
[CBP] Pointer frontend changes + groupshared pointer support (#7848)
Resolves #7628 Resolves: #8197 Primary Goals: 1. Add `Access` to pointer 2. AddressSpace::GroupShared support for pointers (SPIR-V) 3. Add `__getAddress()` to replace `&` * `&` is not updated to `require(cpu)` since slangpy uses `&`. This means we must: (1) merge PR; (2) replace `&` with `__getAddress()`; (3) add `require(cpu)` to `&` Changes: * Added to `Ptr` the `Access` generic argument & logic (for `Access::Read`). * Moved the generic argument `AddressSpace` from `Ptr` to the end of the type. * Added pointer casting support between any `Ptr` as long as the `AddressSpace` is the same * Disallow globallycoherent T* and coherent T* * Disallow const T*, T const*, and const T* * Fixed .natvis display of `ConstantValue` `ValOperandNode` * Support generic resolution of type-casted integers * Added `VariablePointer` emitting for spirv + other minor logic needed for groupshared pointers Breaking Changes: * Anyone using the `AddressSpace` of `Ptr` will now have to account for the `Access` argument * we disallow various syntax paired with `Ptr` and `T*` --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--docs/command-line-slangc-reference.md1
-rw-r--r--source/slang/core.meta.slang180
-rw-r--r--source/slang/hlsl.meta.slang2
-rw-r--r--source/slang/slang-ast-builder.cpp39
-rw-r--r--source/slang/slang-ast-builder.h16
-rw-r--r--source/slang/slang-ast-expr.h7
-rw-r--r--source/slang/slang-ast-iterator.h1
-rw-r--r--source/slang/slang-ast-modifier.h1
-rw-r--r--source/slang/slang-ast-print.cpp9
-rw-r--r--source/slang/slang-ast-support-types.cpp4
-rw-r--r--source/slang/slang-ast-type.cpp59
-rw-r--r--source/slang/slang-ast-type.h2
-rw-r--r--source/slang/slang-ast-val.cpp4
-rw-r--r--source/slang/slang-base-type-info.cpp8
-rw-r--r--source/slang/slang-base-type-info.h2
-rw-r--r--source/slang/slang-capabilities.capdef4
-rw-r--r--source/slang/slang-check-constraint.cpp2
-rw-r--r--source/slang/slang-check-conversion.cpp7
-rw-r--r--source/slang/slang-check-decl.cpp18
-rw-r--r--source/slang/slang-check-expr.cpp172
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-modifier.cpp12
-rw-r--r--source/slang/slang-check-overload.cpp17
-rw-r--r--source/slang/slang-check-shader.cpp2
-rw-r--r--source/slang/slang-diagnostic-defs.h39
-rw-r--r--source/slang/slang-emit-spirv.cpp24
-rw-r--r--source/slang/slang-ir-autodiff.cpp2
-rw-r--r--source/slang/slang-ir-explicit-global-context.cpp5
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp16
-rw-r--r--source/slang/slang-ir-insts.h38
-rw-r--r--source/slang/slang-ir-legalize-types.cpp1
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp6
-rw-r--r--source/slang/slang-ir-specialize-address-space.cpp15
-rw-r--r--source/slang/slang-ir-specialize-function-call.cpp12
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp63
-rw-r--r--source/slang/slang-ir-translate-global-varying-var.cpp4
-rw-r--r--source/slang/slang-ir.cpp47
-rw-r--r--source/slang/slang-ir.h13
-rw-r--r--source/slang/slang-lower-to-ir.cpp46
-rw-r--r--source/slang/slang-mangle.cpp10
-rw-r--r--source/slang/slang-parser.cpp13
-rw-r--r--source/slang/slang-syntax.cpp2
-rw-r--r--source/slang/slang-type-system-shared.h24
-rw-r--r--source/slang/slang.natvis8
-rw-r--r--tests/autodiff/get-offset-ptr.slang40
-rw-r--r--tests/bugs/gh-3601.slang8
-rw-r--r--tests/diagnostics/invalid-constant-pointer-taking.slang16
-rw-r--r--tests/language-feature/bitfield/msvc-repr-mixed.slang13
-rw-r--r--tests/language-feature/capability/address-of.slang17
-rw-r--r--tests/language-feature/pointer/const-ptr-variations.slang40
-rw-r--r--tests/language-feature/pointer/get-address-validation.slang82
-rw-r--r--tests/language-feature/pointer/globallycoherent-ptr.slang20
-rw-r--r--tests/language-feature/pointer/groupshared-ptr-of-device.slang28
-rw-r--r--tests/language-feature/pointer/pointer-access/pointer-access-frontend.slang14
-rw-r--r--tests/language-feature/pointer/pointer-access/read-only-pointer-1.slang41
-rw-r--r--tests/language-feature/pointer/pointer-access/read-only-pointer-2.slang19
-rw-r--r--tests/language-feature/pointer/pointer-casting/pointer-casting-rules.slang51
-rw-r--r--tests/language-feature/pointer/pointer-self-reference.slang10
-rw-r--r--tests/language-feature/pointer/ptr-to-groupshared.slang30
-rw-r--r--tests/spirv/pointer-from-user-guide.slang2
-rw-r--r--tests/spirv/pointer.slang4
-rw-r--r--tests/spirv/ptr-vector-member.slang20
-rw-r--r--tools/gfx/gfx.slang2
63 files changed, 1160 insertions, 256 deletions
diff --git a/docs/command-line-slangc-reference.md b/docs/command-line-slangc-reference.md
index 36947d7b4..3e250df88 100644
--- a/docs/command-line-slangc-reference.md
+++ b/docs/command-line-slangc-reference.md
@@ -1267,6 +1267,7 @@ A capability describes an optional feature that a target may or may not support.
* `any_cpp_target`
* `cpp_cuda`
* `cpp_cuda_spirv`
+* `cpp_cuda_metal_spirv`
* `cuda_spirv`
* `cpp_cuda_glsl_spirv`
* `cpp_cuda_glsl_hlsl`
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 3306403f5..5d2a80c29 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1278,18 +1278,52 @@ struct __none_t
{
};
+// @hidden: this type is a BaseType since we want it to work with
+// `registerBuiltinDecl`
+__builtin_type($((int)BaseType::AddressSpace))
+enum AddressSpace : uint64_t
+{
+ Device = $((uint64_t)AddressSpace::UserPointer),
+ GroupShared = $((uint64_t)AddressSpace::GroupShared),
+};
+
+// @hidden: this type is a BaseType since we want it to work with
+// `registerBuiltinDecl`
+__builtin_type($((int)BaseType::MemoryScope))
+enum MemoryScope : int32_t
+{
+ CrossDevice = $((int32_t)MemoryScope::CrossDevice),
+ Device = $((int32_t)MemoryScope::Device),
+ Workgroup = $((int32_t)MemoryScope::Workgroup),
+ Subgroup = $((int32_t)MemoryScope::Subgroup),
+ Invocation = $((int32_t)MemoryScope::Invocation),
+ QueueFamily = $((int32_t)MemoryScope::QueueFamily),
+}
+
+// @hidden: this type is a BaseType since we want it to work with
+// `registerBuiltinDecl`
+__builtin_type($((int)BaseType::AccessQualifier))
+enum Access : uint64_t
+{
+ ReadWrite = $((uint64_t)AccessQualifier::ReadWrite),
+ Read = $((uint64_t)AccessQualifier::Read),
+}
+
//@public:
/// Represents a pointer type.
/// @param T The type of the value pointed to.
/// @remarks `T* val` is equivalent to `Ptr<T> val`.
-__generic<T, let addrSpace : uint64_t = $((uint64_t)AddressSpace::UserPointer)ULL>
__magic_type(PtrType)
__intrinsic_type($(kIROp_PtrType))
-struct Ptr
+struct Ptr<
+ T,
+ Access access = Access::ReadWrite,
+ AddressSpace addrSpace = AddressSpace::Device>
{
- __generic<U>
+ // A user is allowed to explicitly cast between any pointer type of
+ // the same address space
__intrinsic_op($(kIROp_BitCast))
- __init(Ptr<U, addrSpace> ptr);
+ __init<U, Access accessOther>(Ptr<U, accessOther, addrSpace> ptr);
__intrinsic_op($(kIROp_CastIntToPtr))
__init(uint64_t val);
@@ -1297,16 +1331,30 @@ struct Ptr
__intrinsic_op($(kIROp_CastIntToPtr))
__init(int64_t val);
+ // By default, getter is not an L value
__generic<TInt : __BuiltinIntegerType>
__subscript(TInt index) -> T
{
- // If a 'Ptr[index]' is referred to by a '__ref', call 'kIROp_GetOffsetPtr(index)'
__intrinsic_op($(kIROp_GetOffsetPtr))
[nonmutating]
ref;
}
};
+extension<T, AddressSpace addrSpace> Ptr<T, Access::ReadWrite, addrSpace>
+{
+ // We have a `ref` accessor if we are ReadWrite. This means only `ReadWrite`
+ // can be used as an L-value.
+ __generic<TInt : __BuiltinIntegerType>
+ __subscript(TInt index) -> Ref<T>
+ {
+ // If a 'Ptr[index]' is referred to by a '__ref', call 'kIROp_GetOffsetPtr(index)'
+ __intrinsic_op($(kIROp_GetOffsetPtr))
+ [nonmutating]
+ ref;
+ }
+}
+
//@hidden:
__intrinsic_op($(kIROp_AlignedAttr))
void __align_attr(int alignment);
@@ -1348,50 +1396,64 @@ void storeAligned<int alignment, T>(T* ptr, T value)
__store_aligned(ptr, value, __align_attr(alignment));
}
+${{{
+ StringBuilder ptrTypeParameterListBuilder;
+ ptrTypeParameterListBuilder << "T, Access access, AddressSpace addrSpace";
+ String ptrTypeParameterList = ptrTypeParameterListBuilder.toString();
+
+ StringBuilder ptrArgListBuilder;
+ ptrArgListBuilder << "T, access, addrSpace";
+ String ptrArgList = ptrArgListBuilder.toString();
+
+ StringBuilder fullPtrTypeBuilder;
+ fullPtrTypeBuilder << "Ptr<" << ptrArgList << ">";
+ String fullPtrType = fullPtrTypeBuilder.toString();
+
+}}}
//@hidden:
__intrinsic_op($(kIROp_Load))
-T __load<T, let addrSpace : uint64_t>(Ptr<T, addrSpace> ptr);
+T __load<$(ptrTypeParameterList)>($(fullPtrType) ptr);
__intrinsic_op($(kIROp_Store))
-void __store<T, let addrSpace : uint64_t>(Ptr<T, addrSpace> ptr, T val);
+void __store<$(ptrTypeParameterList)>($(fullPtrType) ptr, T val);
__intrinsic_op($(kIROp_GetElementPtr))
-Ptr<T, addrSpace> __getElementPtr<T, let addrSpace : uint64_t, TIndex : __BuiltinIntegerType>(Ptr<T, addrSpace> ptr, TIndex index);
+$(fullPtrType) __getElementPtr<$(ptrTypeParameterList), TIndex : __BuiltinIntegerType>($(fullPtrType) ptr, TIndex index);
__intrinsic_op($(kIROp_GetOffsetPtr))
-Ptr<T, addrSpace> __getOffsetPtr<T, let addrSpace : uint64_t, TIndex : __BuiltinIntegerType>(Ptr<T, addrSpace> ptr, TIndex index);
+$(fullPtrType) __getOffsetPtr<$(ptrTypeParameterList), TIndex : __BuiltinIntegerType>($(fullPtrType) ptr, TIndex index);
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_Less))
-bool operator <(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
+bool operator <($(fullPtrType) p1, $(fullPtrType) p2);
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_Leq))
-bool operator <=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
+bool operator <=($(fullPtrType) p1, $(fullPtrType) p2);
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_Greater))
-bool operator>(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
+bool operator>($(fullPtrType) p1, $(fullPtrType) p2);
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_Geq))
-bool operator >=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
+bool operator >=($(fullPtrType) p1, $(fullPtrType) p2);
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_Neq))
-bool operator !=(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
+bool operator !=($(fullPtrType) p1, $(fullPtrType) p2);
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_Eql))
-bool operator ==(Ptr<T, addrSpace> p1, Ptr<T, addrSpace> p2);
+bool operator ==($(fullPtrType) p1, $(fullPtrType) p2);
//@public:
extension bool : IRangedValue
{
- __generic<T, let addrSpace : uint64_t>
+ __generic<$(ptrTypeParameterList)>
__implicit_conversion($(kConversionCost_PtrToBool))
__intrinsic_op($(kIROp_CastPtrToBool))
- __init(Ptr<T, addrSpace> ptr);
+ __init($(fullPtrType) ptr);
__generic<T : __EnumType>
__implicit_conversion($(kConversionCost_IntegerTruncate))
@@ -1407,9 +1469,9 @@ extension bool : IRangedValue
extension uint64_t : IRangedValue
{
- __generic<T, let addrSpace : uint64_t>
+ __generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_CastPtrToInt))
- __init(Ptr<T, addrSpace> ptr);
+ __init($(fullPtrType) ptr);
static const uint64_t maxValue = 0xFFFFFFFFFFFFFFFFULL;
static const uint64_t minValue = 0;
@@ -1417,9 +1479,9 @@ extension uint64_t : IRangedValue
extension int64_t : IRangedValue
{
- __generic<T, let addrSpace : uint64_t>
+ __generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_CastPtrToInt))
- __init(Ptr<T, addrSpace> ptr);
+ __init($(fullPtrType) ptr);
static const int64_t maxValue = 0x7FFFFFFFFFFFFFFFLL;
static const int64_t minValue = -0x8000000000000000LL;
@@ -1427,9 +1489,9 @@ extension int64_t : IRangedValue
extension intptr_t : IRangedValue
{
- __generic<T, let addrSpace : uint64_t>
+ __generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_CastPtrToInt))
- __init(Ptr<T, addrSpace> ptr);
+ __init($(fullPtrType) ptr);
static const intptr_t maxValue = $(SLANG_PROCESSOR_X86_64?"0x7FFFFFFFFFFFFFFFz":"0x7FFFFFFFz");
static const intptr_t minValue = $(SLANG_PROCESSOR_X86_64?"0x8000000000000000z":"0x80000000z");
static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4");
@@ -1437,9 +1499,9 @@ extension intptr_t : IRangedValue
extension uintptr_t : IRangedValue
{
- __generic<T, let addrSpace : uint64_t>
+ __generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_CastPtrToInt))
- __init(Ptr<T, addrSpace> ptr);
+ __init($(fullPtrType) ptr);
static const uintptr_t maxValue = $(SLANG_PROCESSOR_X86_64?"0xFFFFFFFFFFFFFFFFz":"0xFFFFFFFFz");
static const uintptr_t minValue = 0z;
static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4");
@@ -1470,7 +1532,9 @@ __intrinsic_type($(kIROp_ConstRefType))
struct ConstRef
{};
-typealias __Addr<T> = Ptr<T, $((uint64_t)AddressSpace::Generic)ULL>;
+// __Addr<T> is AddressSpace::Generic since Slang will specalize & validate the address-space
+// internally to a concrete address-space.
+typealias __Addr<T> = Ptr<T, Access::ReadWrite, (AddressSpace)$((uint64_t)AddressSpace::Generic)>;
//@public:
@@ -1828,16 +1892,16 @@ struct NativeString
__init() { this = NativeString(""); }
};
-extension Ptr<void>
+extension<Access access> Ptr<void, access>
{
__implicit_conversion($(kConversionCost_PtrToVoidPtr))
[__unsafeForceInlineEarly]
- __init(NativeString nativeStr) { this = nativeStr.getBuffer(); }
+ __init(NativeString nativeStr) { this = Ptr<void, access>(nativeStr.getBuffer()); }
- __generic<T, let addrSpace : uint64_t>
+ __generic<$(ptrTypeParameterList)>
__intrinsic_op($(kIROp_BitCast))
__implicit_conversion($(kConversionCost_PtrToVoidPtr))
- __init(Ptr<T, addrSpace> ptr);
+ __init($(fullPtrType) ptr);
__generic<T>
__intrinsic_op($(kIROp_BitCast))
@@ -2607,29 +2671,31 @@ for (auto op : intrinsicUnaryOps)
}}}}
-__generic<T, let addrSpace : uint64_t>
+// Only ReadWrite is an L-value.
+__generic<T, AddressSpace addrSpace>
__intrinsic_op(0)
-[require(cpp_cuda_spirv)]
-__prefix Ref<T> operator*(Ptr<T, addrSpace> value);
+__prefix Ref<T> operator*(Ptr<T, Access::ReadWrite, addrSpace> value);
-__generic<T>
+// Unknown access qualifier or Access::Read access qualifier is a promise
+// that the pointer is not going to be used as an L-value.
+__generic<$(ptrTypeParameterList)>
__intrinsic_op(0)
-[KnownBuiltin($( (int)KnownBuiltinDeclName::OperatorAddressOf))]
-[require(cpp_cuda_spirv)]
-__prefix Ptr<T, $((uint64_t)AddressSpace::UserPointer)ULL> operator&(__ref T value);
+__prefix ConstRef<T> operator*($(fullPtrType) value);
+// TODO: [require(cpu)]. This cannot be done yet since this change breaks slangpy
__generic<T>
__intrinsic_op(0)
+[KnownBuiltin( $((int)KnownBuiltinDeclName::OperatorAddressOf))]
[require(cpp_cuda_spirv)]
-__Addr<T> __get_addr( __ref T value);
+__prefix Ptr<T, Access::ReadWrite, AddressSpace::Device> operator&(__ref T value);
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList), TInt : __BuiltinIntegerType>
__intrinsic_op($(kIROp_GetOffsetPtr))
-Ptr<T, addrSpace> operator+(Ptr<T, addrSpace> value, int64_t offset);
+$(fullPtrType) operator+($(fullPtrType) value, TInt offset);
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList), TInt : __BuiltinIntegerType>
[__unsafeForceInlineEarly]
-Ptr<T, addrSpace> operator -(Ptr<T, addrSpace> value, int64_t offset)
+$(fullPtrType) operator-($(fullPtrType) value, TInt offset)
{
return __getOffsetPtr(value, -offset);
}
@@ -2694,9 +2760,9 @@ matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value)
{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
$(fixity.qual)
-__generic<T, let addrSpace : uint64_t>
+__generic<$(ptrTypeParameterList)>
[__unsafeForceInlineEarly]
-Ptr<T, addrSpace> operator$(op.name)(in out Ptr<T, addrSpace> value)
+$(fullPtrType) operator$(op.name)(in out $(fullPtrType) value)
{$(fixity.bodyPrefix) value = value $(op.binOp) 1; return $(fixity.returnVal); }
${{{{
@@ -3556,18 +3622,6 @@ enum MemoryOrder
SeqCst = $(kIRMemoryOrder_SeqCst),
}
-// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id
-enum MemoryScope
-{
- CrossDevice = 0,
- Device = 1,
- Workgroup = 2,
- Subgroup = 3,
- Invocation = 4,
- QueueFamily = 5,
- ShaderCallKHR = 6,
-};
-
/// Represents types that can be used in any atomic operations.
/// Implemented by builtin scalar types: `int`, `uint`, `int64_t`, `uint64_t`, `int8_t`, `uint8_t`, `int16_t`, `uint16_t`, `float`, `double` and `half`.
[sealed] interface IAtomicable {}
@@ -4307,7 +4361,7 @@ __attributeTarget(FuncDecl)
attribute_syntax [RequireFullQuads] : RequireFullQuadsAttribute;
__generic<T>
-typealias NodePayloadPtr = Ptr<T, $((uint64_t)AddressSpace::NodePayloadAMDX)>;
+typealias NodePayloadPtr = Ptr<T, Access::ReadWrite, (AddressSpace)$((uint64_t)AddressSpace::NodePayloadAMDX)>;
__attributeTarget(StructDecl)
attribute_syntax [raypayload] : RayPayloadAttribute;
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 23afb3297..2af0dbcf7 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -19859,7 +19859,7 @@ __Addr<T> __allocHitObjectAttributes<T>()
{
[__vulkanHitObjectAttributes]
static T t;
- return __get_addr(t);
+ return __getAddress(t);
}
// Next is the custom intrinsic that will compute the hitObjectAttributes location
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index a71abf570..5da4e9521 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -461,9 +461,17 @@ Type* ASTBuilder::getSpecializedBuiltinType(ArrayView<Val*> genericArgs, const c
return rsType;
}
-PtrType* ASTBuilder::getPtrType(Type* valueType, AddressSpace addrSpace)
+PtrType* ASTBuilder::getPtrType(Type* valueType, Val* accessQualifier, Val* addrSpace)
{
- return dynamicCast<PtrType>(getPtrType(valueType, addrSpace, "PtrType"));
+ return dynamicCast<PtrType>(getPtrType(valueType, accessQualifier, addrSpace, "PtrType"));
+}
+
+PtrType* ASTBuilder::getPtrType(
+ Type* valueType,
+ AccessQualifier accessQualifier,
+ AddressSpace addrSpace)
+{
+ return dynamicCast<PtrType>(getPtrType(valueType, accessQualifier, addrSpace, "PtrType"));
}
Type* ASTBuilder::getDefaultLayoutType()
@@ -489,11 +497,6 @@ Type* ASTBuilder::getScalarLayoutType()
return getSpecializedBuiltinType({}, "ScalarDataLayoutType");
}
-Type* ASTBuilder::getCLayoutType()
-{
- return getSpecializedBuiltinType({}, "CDataLayoutType");
-}
-
// Construct the type `Out<valueType>`
OutType* ASTBuilder::getOutType(Type* valueType)
{
@@ -505,9 +508,9 @@ InOutType* ASTBuilder::getInOutType(Type* valueType)
return dynamicCast<InOutType>(getPtrType(valueType, "InOutType"));
}
-RefType* ASTBuilder::getRefType(Type* valueType, AddressSpace addrSpace)
+RefType* ASTBuilder::getRefType(Type* valueType)
{
- return dynamicCast<RefType>(getPtrType(valueType, addrSpace, "RefType"));
+ return dynamicCast<RefType>(getPtrType(valueType, "RefType"));
}
ConstRefType* ASTBuilder::getConstRefType(Type* valueType)
@@ -528,13 +531,27 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
PtrTypeBase* ASTBuilder::getPtrType(
Type* valueType,
- AddressSpace addrSpace,
+ Val* accessQualifier,
+ Val* addrSpace,
char const* ptrTypeName)
{
- Val* args[] = {valueType, getIntVal(getUInt64Type(), (IntegerLiteralValue)addrSpace)};
+ Val* args[] = {valueType, accessQualifier, addrSpace};
return as<PtrTypeBase>(getSpecializedBuiltinType(makeArrayView(args), ptrTypeName));
}
+PtrTypeBase* ASTBuilder::getPtrType(
+ Type* valueType,
+ AccessQualifier accessQualifier,
+ AddressSpace addrSpace,
+ char const* ptrTypeName)
+{
+ return as<PtrTypeBase>(getPtrType(
+ valueType,
+ getIntVal(getBuiltinType(BaseType::AccessQualifier), (IntegerLiteralValue)accessQualifier),
+ getIntVal(getBuiltinType(BaseType::AddressSpace), (IntegerLiteralValue)addrSpace),
+ ptrTypeName));
+}
+
ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount)
{
if (!elementCount)
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 798e1ddc0..e71e2665f 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -530,7 +530,8 @@ public:
Type* getDiffInterfaceType() { return m_sharedASTBuilder->getDiffInterfaceType(); }
// Construct the type `Ptr<valueType>`, where `Ptr`
// is looked up as a builtin type.
- PtrType* getPtrType(Type* valueType, AddressSpace addrSpace);
+ PtrType* getPtrType(Type* valueType, AccessQualifier accessQualifier, AddressSpace addrSpace);
+ PtrType* getPtrType(Type* valueType, Val* accessQualifier, Val* addrSpace);
// Construct the type `Out<valueType>`
OutType* getOutType(Type* valueType);
@@ -539,7 +540,7 @@ public:
InOutType* getInOutType(Type* valueType);
// Construct the type `Ref<valueType>`
- RefType* getRefType(Type* valueType, AddressSpace addrSpace);
+ RefType* getRefType(Type* valueType);
// Construct the type `ConstRef<valueType>`
ConstRefType* getConstRefType(Type* valueType);
@@ -550,7 +551,16 @@ public:
// Construct a pointer type like `Ptr<valueType>`, but where
// the actual type name for the pointer type is given by `ptrTypeName`
PtrTypeBase* getPtrType(Type* valueType, char const* ptrTypeName);
- PtrTypeBase* getPtrType(Type* valueType, AddressSpace addrSpace, char const* ptrTypeName);
+ PtrTypeBase* getPtrType(
+ Type* valueType,
+ Val* accessQualifier,
+ Val* addrSpace,
+ char const* ptrTypeName);
+ PtrTypeBase* getPtrType(
+ Type* valueType,
+ AccessQualifier accessQualifier,
+ AddressSpace addrSpace,
+ char const* ptrTypeName);
ArrayExpressionType* getArrayType(Type* elementType, IntVal* elementCount);
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index adbc7a2ba..fb0ac2a67 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -503,6 +503,13 @@ class CountOfExpr : public SizeOfLikeExpr
};
FIDDLE()
+class AddressOfExpr : public Expr
+{
+ FIDDLE(...)
+ FIDDLE() Expr* arg = nullptr;
+};
+
+FIDDLE()
class MakeOptionalExpr : public Expr
{
FIDDLE(...)
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index c29a42665..379de0560 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -251,6 +251,7 @@ struct ASTIterator
iterator->maybeDispatchCallback(expr);
}
void visitReturnValExpr(ReturnValExpr* expr) { iterator->maybeDispatchCallback(expr); }
+ void visitAddressOfExpr(AddressOfExpr* expr) { iterator->maybeDispatchCallback(expr); }
void visitAndTypeExpr(AndTypeExpr* expr)
{
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 2f9c29d6c..89f7a70bb 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -583,7 +583,6 @@ class BuiltinRequirementModifier : public Modifier
FIDDLE() BuiltinRequirementKind kind;
};
-
// A modifier applied to declarations of builtin types to indicate how they
// should be lowered to the IR.
//
diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp
index 1f7478662..299e6a859 100644
--- a/source/slang/slang-ast-print.cpp
+++ b/source/slang/slang-ast-print.cpp
@@ -610,6 +610,15 @@ void ASTPrinter::addExpr(Expr* expr)
}
sb << ")";
}
+ else if (const auto addressOfExpr = as<AddressOfExpr>(expr))
+ {
+ sb << "__getAddress(";
+ if (addressOfExpr->arg)
+ {
+ addExpr(addressOfExpr->arg);
+ }
+ sb << ")";
+ }
else if (const auto makeOptionalExpr = as<MakeOptionalExpr>(expr))
{
if (makeOptionalExpr->value)
diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp
index 3ac352f0a..1f5f4f8b2 100644
--- a/source/slang/slang-ast-support-types.cpp
+++ b/source/slang/slang-ast-support-types.cpp
@@ -15,6 +15,10 @@ QualType::QualType(Type* type)
{
isLeftValue = true;
}
+ else if (as<ConstRefType>(type))
+ {
+ isLeftValue = false;
+ }
}
void removeModifier(ModifiableSyntaxNode* syntax, Modifier* toRemove)
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 8a224b305..53f6626d7 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -439,11 +439,28 @@ Type* NativeRefType::getValueType()
return as<Type>(_getGenericTypeArg(this, 0));
}
-Val* PtrTypeBase::getAddressSpace()
+
+Val* PtrTypeBase::getAccessQualifier()
{
return _getGenericTypeArg(this, 1);
}
+Val* PtrTypeBase::getAddressSpace()
+{
+ return _getGenericTypeArg(this, 2);
+}
+
+AccessQualifier tryGetAccessQualifierValue(Val* val)
+{
+ AccessQualifier accessQualifier = AccessQualifier::ReadWrite;
+
+ if (auto cintVal = as<ConstantIntVal>(val))
+ {
+ accessQualifier = (AccessQualifier)(cintVal->getValue());
+ }
+ return accessQualifier;
+}
+
AddressSpace tryGetAddressSpaceValue(Val* addrSpaceVal)
{
AddressSpace addrSpace = AddressSpace::Generic;
@@ -460,19 +477,38 @@ void maybePrintAddrSpaceOperand(StringBuilder& out, AddressSpace addrSpace)
switch (addrSpace)
{
case AddressSpace::Generic:
+ out << toSlice(", AddressSpace::Generic");
+ break;
case AddressSpace::UserPointer:
+ // We expose UserPointer as Device to users
+ out << toSlice(", AddressSpace::Device");
break;
case AddressSpace::GroupShared:
- out << toSlice(", groupshared");
+ out << toSlice(", AddressSpace::GroupShared");
break;
case AddressSpace::Global:
- out << toSlice(", global");
+ out << toSlice(", AddressSpace::Global");
break;
case AddressSpace::ThreadLocal:
- out << toSlice(", threadlocal");
+ out << toSlice(", AddressSpace::ThreadLocal");
break;
case AddressSpace::Uniform:
- out << toSlice(", uniform");
+ out << toSlice(", AddressSpace::Uniform");
+ break;
+ default:
+ break;
+ }
+}
+
+void maybePrintAccessQualifierOperand(StringBuilder& out, AccessQualifier accessQualifier)
+{
+ switch (accessQualifier)
+ {
+ case AccessQualifier::ReadWrite:
+ out << toSlice(", Access::ReadWrite");
+ break;
+ case AccessQualifier::Read:
+ out << toSlice(", Access::Read");
break;
default:
break;
@@ -481,20 +517,21 @@ void maybePrintAddrSpaceOperand(StringBuilder& out, AddressSpace addrSpace)
void PtrType::_toTextOverride(StringBuilder& out)
{
+ auto accessQualifier = tryGetAccessQualifierValue(getAccessQualifier());
auto addrSpace = tryGetAddressSpaceValue(getAddressSpace());
- if (addrSpace == AddressSpace::Generic)
- out << toSlice("Addr<") << getValueType();
- else
- out << toSlice("Ptr<") << getValueType();
+ out << toSlice("Ptr<") << getValueType();
+ maybePrintAccessQualifierOperand(out, accessQualifier);
maybePrintAddrSpaceOperand(out, addrSpace);
out << toSlice(">");
}
void RefType::_toTextOverride(StringBuilder& out)
{
+ auto accessQualifier = tryGetAccessQualifierValue(getAccessQualifier());
+ auto addrSpace = tryGetAddressSpaceValue(getAddressSpace());
out << toSlice("Ref<") << getValueType();
- auto addressSpaceVal = getAddressSpace();
- maybePrintAddrSpaceOperand(out, tryGetAddressSpaceValue(addressSpaceVal));
+ maybePrintAccessQualifierOperand(out, accessQualifier);
+ maybePrintAddrSpaceOperand(out, addrSpace);
out << toSlice(">");
}
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 842af8b88..4994328b2 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -685,7 +685,7 @@ class PtrTypeBase : public BuiltinType
FIDDLE(...)
// Get the type of the pointed-to value.
Type* getValueType();
-
+ Val* getAccessQualifier();
Val* getAddressSpace();
};
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 43bd99f19..96f0a1682 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -1435,6 +1435,10 @@ Val* TypeCastIntVal::tryFoldImpl(
case BaseType::UInt8:
resultValue = (uint8_t)resultValue;
return true;
+ case BaseType::AddressSpace:
+ case BaseType::AccessQualifier:
+ case BaseType::MemoryScope:
+ return true;
default:
return false;
}
diff --git a/source/slang/slang-base-type-info.cpp b/source/slang/slang-base-type-info.cpp
index 9072e34e4..984437ca8 100644
--- a/source/slang/slang-base-type-info.cpp
+++ b/source/slang/slang-base-type-info.cpp
@@ -4,7 +4,7 @@
namespace Slang
{
-/* static */ const BaseTypeInfo BaseTypeInfo::s_info[Index(BaseType::CountOf)] = {
+/* static */ const BaseTypeInfo BaseTypeInfo::s_info[Index(BaseType::CountOfPrimitives)] = {
{0, 0, uint8_t(BaseType::Void)},
{uint8_t(sizeof(bool)), 0, uint8_t(BaseType::Bool)},
{uint8_t(sizeof(int8_t)),
@@ -84,6 +84,12 @@ namespace Slang
return UnownedStringSlice::fromLiteral("intptr_t");
case BaseType::UIntPtr:
return UnownedStringSlice::fromLiteral("uintptr_t");
+ case BaseType::AddressSpace:
+ return UnownedStringSlice::fromLiteral("AddressSpace");
+ case BaseType::MemoryScope:
+ return UnownedStringSlice::fromLiteral("MemoryScope");
+ case BaseType::AccessQualifier:
+ return UnownedStringSlice::fromLiteral("Access");
default:
{
SLANG_ASSERT(!"Unknown basic type");
diff --git a/source/slang/slang-base-type-info.h b/source/slang/slang-base-type-info.h
index 4b96af18f..bad70c6fa 100644
--- a/source/slang/slang-base-type-info.h
+++ b/source/slang/slang-base-type-info.h
@@ -44,7 +44,7 @@ struct BaseTypeInfo
static bool check();
private:
- static const BaseTypeInfo s_info[Index(BaseType::CountOf)];
+ static const BaseTypeInfo s_info[Index(BaseType::CountOfPrimitives)];
};
} // namespace Slang
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 0ea43a8df..822356312 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -271,6 +271,10 @@ alias cpp_cuda = cpp | cuda;
/// [Compound]
alias cpp_cuda_spirv = cpp | cuda | spirv;
+/// CPP, CUDA, Metal, and SPIRV code-gen targets
+/// [Compound]
+alias cpp_cuda_metal_spirv = cpp | cuda | metal | spirv;
+
/// CUDA and SPIRV code-gen targets
/// [Compound]
alias cuda_spirv = cuda | spirv;
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 7c6f8929a..9355834b5 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -121,7 +121,7 @@ Type* SemanticsVisitor::_tryJoinTypeWithInterface(
ConversionCost bestCost = kConversionCost_Explicit;
if (auto basicType = dynamicCast<BasicExpressionType>(type))
{
- for (Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf);
+ for (Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOfPrimitives);
baseTypeFlavorIndex++)
{
// Don't consider `type`, since we already know it doesn't work.
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 758c23a5f..4d15fd840 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -1760,6 +1760,13 @@ bool SemanticsVisitor::_coerce(
if (sink)
{
sink->diagnose(fromExpr, Diagnostics::ambiguousConversion, fromType, toType);
+ for (auto candidate : overloadContext.bestCandidates)
+ {
+ sink->diagnose(
+ candidate.item.declRef,
+ Diagnostics::seeDeclarationOf,
+ candidate.item.declRef);
+ }
}
*outToExpr = CreateErrorExpr(fromExpr);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 507e12fa6..e59cf6ad5 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -678,8 +678,12 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase,
return;
return DeclVisitor<VisitorType>::dispatch(val);
}
+
// Expr Visitor
void visitExpr(Expr*) {}
+
+ void visitOpenRefExpr(OpenRefExpr* expr) { dispatchIfNotNull(expr->innerExpr); }
+
void visitIndexExpr(IndexExpr* subscriptExpr)
{
for (auto arg : subscriptExpr->indexExprs)
@@ -695,6 +699,7 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase,
dispatchIfNotNull(element);
}
+ void visitAddressOfExpr(AddressOfExpr* expr) { dispatchIfNotNull(expr->arg); }
void visitAssignExpr(AssignExpr* expr)
{
@@ -2360,6 +2365,13 @@ void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
addModifier(varDecl, m_astBuilder->create<ExternCppModifier>());
}
+ // Not allowed a `globallycoherent T*` or related
+ if (as<PtrType>(varDecl->type))
+ if (auto memoryQualifierSet = varDecl->findModifier<MemoryQualifierSetModifier>())
+ if (memoryQualifierSet->getMemoryQualifierBit() &
+ MemoryQualifierSetModifier::Flags::kCoherent)
+ getSink()->diagnose(varDecl, Diagnostics::coherentKeywordOnAPointer);
+
// Check for static const variables without initializers
if (!varDecl->initExpr)
{
@@ -14379,6 +14391,12 @@ struct CapabilityDeclReferenceVisitor
{
handleProcessFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc);
}
+ void visitAddressOfExpr(AddressOfExpr* expr)
+ {
+ // __getAddress only works with certain targets
+ handleProcessFunc(expr, CapabilitySet(CapabilityName::cpp_cuda_metal_spirv), expr->loc);
+ this->dispatchIfNotNull(expr->arg);
+ }
void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
{
CapabilitySet set;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a874eaf43..ec249f56d 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -230,6 +230,7 @@ Expr* SemanticsVisitor::maybeOpenRef(Expr* expr)
openRef->type.isLeftValue = (as<RefType>(exprType) != nullptr);
openRef->type.type = refType->getValueType();
openRef->checked = true;
+ openRef->loc = expr->loc;
return openRef;
}
return expr;
@@ -4111,6 +4112,167 @@ Expr* SemanticsExprVisitor::visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr)
return sizeOfLikeExpr;
}
+// Determines if we have a valid `AddressOf` target.
+// Target to validate is `baseExpr`.
+// Original type is `targetType`.
+static PtrType* getValidTypeForAddressOf(
+ SemanticsVisitor* visitor,
+ ASTBuilder* m_astBuilder,
+ Expr* baseExpr,
+ Type* targetType)
+{
+
+ // If our base is a variable like expression, we should check if this expr is a
+ // block of memory we allow getting the address of.
+ if (auto declRefExpr = as<DeclRefExpr>(baseExpr))
+ {
+ visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::DefinitionChecked);
+ if (auto varDeclRef = as<VarDeclBase>(declRefExpr->declRef))
+ {
+ auto variableType = varDeclRef.substitute(m_astBuilder, targetType);
+ auto varDecl = varDeclRef.getDecl();
+ bool hasVulkanHitObjectAttributesAttribute = false;
+ bool hasHLSLGroupSharedModifier = false;
+ for (auto modifier : varDecl->modifiers)
+ {
+ if (as<VulkanHitObjectAttributesAttribute>(modifier))
+ hasVulkanHitObjectAttributesAttribute = true;
+ else if (as<HLSLGroupSharedModifier>(modifier))
+ hasHLSLGroupSharedModifier = true;
+
+ if (hasVulkanHitObjectAttributesAttribute || hasHLSLGroupSharedModifier)
+ break;
+ }
+
+ // Handle variables tagged as [__vulkanHitObjectAttributes].
+ // This support is needed for an internal "hack" Slang uses
+ // for raytracing with `__allocHitObjectAttributes`.
+ if (hasVulkanHitObjectAttributesAttribute)
+ {
+ return m_astBuilder->getPtrType(
+ variableType,
+ AccessQualifier::ReadWrite,
+ AddressSpace::Generic);
+ }
+ // Handle 'groupshared' variables.
+ else if (hasHLSLGroupSharedModifier)
+ {
+ return m_astBuilder->getPtrType(
+ variableType,
+ AccessQualifier::ReadWrite,
+ AddressSpace::GroupShared);
+ }
+ }
+ }
+
+ // If our base is a variable like expression, which comes from a deref-like operation,
+ // we should check if we are able to return a pointer from that base.
+ auto getPtrTypeFromBaseOfDerefLikeOperation = [&](Expr* baseExpr) -> PtrType*
+ {
+ auto declRefExpr = as<DeclRefExpr>(baseExpr);
+ if (!declRefExpr)
+ return nullptr;
+ visitor->ensureDecl(declRefExpr->declRef, DeclCheckState::DefinitionChecked);
+ auto varDeclRef = as<VarDeclBase>(declRefExpr->declRef);
+ if (!varDeclRef)
+ return nullptr;
+
+ auto variableType = varDeclRef.substitute(m_astBuilder, targetType);
+
+ auto ptrType = as<PtrType>(getType(m_astBuilder, varDeclRef));
+ if (!ptrType)
+ return nullptr;
+
+ return m_astBuilder->getPtrType(
+ variableType,
+ ptrType->getAccessQualifier(),
+ ptrType->getAddressSpace());
+ };
+
+ // This logic handles the recursive lookup of "does our operation lead up
+ // to an addressessable (can take the address-of) section of memory".
+ if (auto indexExpr = as<IndexExpr>(baseExpr))
+ {
+ // If a user chooses to index into an array, we should check if the base
+ // expression is something we can get the address-of.
+ return getValidTypeForAddressOf(
+ visitor,
+ m_astBuilder,
+ indexExpr->baseExpression,
+ targetType);
+ }
+ else if (auto memberExpr = as<MemberExpr>(baseExpr))
+ {
+ // If a user chooses to get a member of a base, we should check if the base
+ // is something we can get the address-of.
+ if (as<VarDeclBase>(memberExpr->declRef))
+ return getValidTypeForAddressOf(
+ visitor,
+ m_astBuilder,
+ memberExpr->baseExpression,
+ targetType);
+ }
+ else if (auto derefExpr = as<DerefExpr>(baseExpr))
+ {
+ // If a user deref's a variable-like-expression, we should
+ // check if this is a base expression we can get the address-of.
+ return getPtrTypeFromBaseOfDerefLikeOperation(derefExpr->base);
+ }
+ else if (auto invokeExpr = as<InvokeExpr>(baseExpr))
+ {
+ // We only want to allow function calls if we are getting the address
+ // of a `GetOffsetPtr` to a pointer-variable
+ auto functionMemberExpr = as<MemberExpr>(invokeExpr->functionExpr);
+ if (!functionMemberExpr)
+ return nullptr;
+ auto subscriptDecl = as<SubscriptDecl>(functionMemberExpr->declRef.getDecl());
+ if (!subscriptDecl)
+ return nullptr;
+ bool isOffsetIntrinsicOp = false;
+ for (auto refAccessor : subscriptDecl->getMembersOfType<RefAccessorDecl>())
+ {
+ auto intrinsicOp = refAccessor->findModifier<IntrinsicOpModifier>();
+ if (!intrinsicOp)
+ continue;
+ if (intrinsicOp->op != kIROp_GetOffsetPtr)
+ continue;
+ isOffsetIntrinsicOp = true;
+ }
+ if (!isOffsetIntrinsicOp)
+ return nullptr;
+
+ return getPtrTypeFromBaseOfDerefLikeOperation(functionMemberExpr->baseExpression);
+ }
+ else if (auto swizzleExpr = as<SwizzleExpr>(baseExpr))
+ {
+ // Only allow swizzle of 1 element since otherwise
+ // we may have a non-contiguous swizzle
+ // (`val.xxy` is non contiguous).
+ if (swizzleExpr->elementIndices.getCount() > 1)
+ return nullptr;
+
+ // Check if the base expression is something we can get the address-of.
+ return getValidTypeForAddressOf(visitor, m_astBuilder, swizzleExpr->base, targetType);
+ }
+ return nullptr;
+}
+
+Expr* SemanticsExprVisitor::visitAddressOfExpr(AddressOfExpr* expr)
+{
+ expr->arg = CheckTerm(expr->arg);
+
+ // This address-of feature is purely experimental and for prototyping.
+ // Only allow known expressions.
+ expr->type =
+ getValidTypeForAddressOf(this, m_astBuilder, expr->arg, getType(m_astBuilder, expr->arg));
+ if (!expr->type)
+ {
+ getSink()->diagnose(expr, Diagnostics::invalidAddressOf);
+ expr->type = m_astBuilder->getErrorType();
+ }
+ return expr;
+}
+
Expr* SemanticsExprVisitor::visitBuiltinCastExpr(BuiltinCastExpr* expr)
{
// All builtin cast exprs should already be checked.
@@ -5744,7 +5906,10 @@ Expr* SemanticsExprVisitor::visitPointerTypeExpr(PointerTypeExpr* expr)
expr->base = CheckProperType(expr->base);
if (as<ErrorType>(expr->base.type))
expr->type = expr->base.type;
- auto ptrType = m_astBuilder->getPtrType(expr->base.type, AddressSpace::UserPointer);
+ auto ptrType = m_astBuilder->getPtrType(
+ expr->base.type,
+ AccessQualifier::ReadWrite,
+ AddressSpace::UserPointer);
expr->type = m_astBuilder->getTypeType(ptrType);
return expr;
}
@@ -5830,6 +5995,11 @@ Val* SemanticsExprVisitor::checkTypeModifier(Modifier* modifier, Type* type)
{
return m_astBuilder->getNoDiffModifierVal();
}
+ else if (as<ConstModifier>(modifier))
+ {
+ getSink()->diagnose(modifier, Diagnostics::constNotAllowedOnType);
+ return nullptr;
+ }
else
{
// TODO: more complete error message here
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 523041697..e6c66ddd3 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -3037,7 +3037,7 @@ public:
}
Expr* visitSizeOfLikeExpr(SizeOfLikeExpr* expr);
-
+ Expr* visitAddressOfExpr(AddressOfExpr* expr);
Expr* visitIncompleteExpr(IncompleteExpr* expr);
Expr* visitBoolLiteralExpr(BoolLiteralExpr* expr);
Expr* visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr);
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index db477ac25..ebda2d637 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -1677,6 +1677,18 @@ Modifier* SemanticsVisitor::checkModifier(
}
}
+ if (as<ConstModifier>(m))
+ {
+ if (auto varDeclBase = as<VarDeclBase>(syntaxNode))
+ {
+ if (as<PointerTypeExpr>(varDeclBase->type.exp))
+ {
+ // Disallow `const T*` syntax.
+ getSink()->diagnose(m, Diagnostics::constNotAllowedOnCStylePtrDecl);
+ return nullptr;
+ }
+ }
+ }
if (auto glslLayoutAttribute = as<UncheckedGLSLLayoutAttribute>(m))
{
return checkGLSLLayoutAttribute(glslLayoutAttribute, syntaxNode);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index f949e2632..e0dbc7e08 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1196,8 +1196,25 @@ Expr* SemanticsVisitor::CompleteOverloadCandidate(
{
// If the subscript decl has a setter,
// then the call is an l-value if base is l-value.
+ //
+ // If Ptr<T, Access> we only need to check for ReadWrite
+ // Access (if ReadWrite result is an LValue. By default a
+ // Ptr<...> is Read-only (unresolved generic argument & Access::Read).
if (auto base = GetBaseExpr(baseExpr))
{
+ if (auto ptrTypeBase = as<PtrTypeBase>(base->type))
+ {
+ auto accessQualifier =
+ as<ConstantIntVal>(ptrTypeBase->getAccessQualifier());
+ if (!accessQualifier ||
+ AccessQualifier(accessQualifier->getValue()) ==
+ AccessQualifier::ReadWrite)
+ {
+ callExpr->type.isLeftValue = true;
+ }
+ break;
+ }
+
if (base->type.isLeftValue)
{
callExpr->type.isLeftValue = true;
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 45deea109..7c9111629 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -784,7 +784,7 @@ Type* getParamTypeWithDirectionWrapper(ASTBuilder* astBuilder, DeclRef<VarDeclBa
case kParameterDirection_InOut:
return astBuilder->getInOutType(result);
case kParameterDirection_Ref:
- return astBuilder->getRefType(result, AddressSpace::Generic);
+ return astBuilder->getRefType(result);
default:
return result;
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index ecc3fb3f3..a30b5f362 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -524,6 +524,14 @@ DIAGNOSTIC(
Error,
missingLayoutBindingModifier,
"Expecting 'binding' modifier in the layout qualifier here")
+DIAGNOSTIC(
+ 20017,
+ Error,
+ constNotAllowedOnCStylePtrDecl,
+ "'const' not allowed on pointer typed declarations using the C style '*' operator. "
+ "If the intent is to restrict the pointed-to value to read-only, use 'Ptr<T, Access.Read>'; "
+ "if the intent is to make the pointer itself immutable, use 'let' or 'const Ptr<...>'.")
+DIAGNOSTIC(20018, Error, constNotAllowedOnType, "cannot use 'const' as a type modifier")
DIAGNOSTIC(
20101,
@@ -702,11 +710,6 @@ DIAGNOSTIC(
argumentExpectedLValue,
"argument passed to parameter '$0' must be l-value.")
DIAGNOSTIC(
- 30078,
- Error,
- cannotTakeConstantPointers,
- "Not allowed to take pointer of an immutable object")
-DIAGNOSTIC(
30048,
Error,
argumentHasMoreMemoryQualifiersThanParam,
@@ -823,7 +826,17 @@ DIAGNOSTIC(
"function, you can replace '$2 $0' with a generic 'T $0' and a 'where T : $2' constraint.")
DIAGNOSTIC(-1, Note, doYouMeanStaticConst, "do you intend to define a `static const` instead?")
DIAGNOSTIC(-1, Note, doYouMeanUniform, "do you intend to define a `uniform` parameter instead?")
-
+DIAGNOSTIC(
+ 30078,
+ Error,
+ coherentKeywordOnAPointer,
+ "cannot have a `globallycoherent T*` or a `coherent T*`, use explicit methods for coherent "
+ "operations instead")
+DIAGNOSTIC(
+ 30079,
+ Error,
+ cannotTakeConstantPointers,
+ "Not allowed to take the address of an immutable object")
DIAGNOSTIC(
30100,
Error,
@@ -927,11 +940,7 @@ DIAGNOSTIC(
Note,
noteExplicitConversionPossible,
"explicit conversion from '$0' to '$1' is possible")
-DIAGNOSTIC(
- 30080,
- Error,
- ambiguousConversion,
- "more than one implicit conversion exists from '$0' to '$1'")
+DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one conversion exists from '$0' to '$1'")
DIAGNOSTIC(
30081,
Warning,
@@ -1432,7 +1441,11 @@ DIAGNOSTIC(
"If this is intended, consider using [NoDiffThis] on the function '$1' to suppress this "
"warning. Alternatively, users can mark the parent struct as [Differentiable] to propagate "
"derivatives.")
-
+DIAGNOSTIC(
+ 31160,
+ Error,
+ invalidAddressOf,
+ "'__getAddress' only supports groupshared variables and members of groupshared/device memory.")
DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1")
DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.")
DIAGNOSTIC(
@@ -2682,6 +2695,8 @@ DIAGNOSTIC(
"cannot perform atomic operation because destination is neither groupshared nor from a device "
"buffer.")
+DIAGNOSTIC(41404, Error, cannotWriteToReadOnlyPointer, "cannot write to a read-only pointer")
+
//
// 5xxxx - Target code generation.
//
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 2127141bd..5036333c1 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1844,7 +1844,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
(IRPtrTypeBase*)inst))
emitOpTypeForwardPointer(resultSpvType, storageClass);
}
- if (storageClass == SpvStorageClassPhysicalStorageBuffer)
+ if (storageClass == SpvStorageClassPhysicalStorageBuffer ||
+ storageClass == SpvStorageClassStorageBuffer)
{
if (m_decoratedSpvInsts.add(getID(resultSpvType)))
{
@@ -3271,6 +3272,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
builder.getPtrType(
kIROp_PtrType,
spvAsmBuiltinVar->getDataType(),
+ AccessQualifier::ReadWrite,
AddressSpace::BuiltinInput),
kind,
spvAsmBuiltinVar);
@@ -3868,20 +3870,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return spvDebugLocalVar;
}
- bool isLegalType(IRInst* type)
+ // Returns true if the given type is allowed to emit for a `DebugVar`.
+ // Other types may not be illegal, but Slang currently does not support
+ // emitting these other DebugVar types.
+ bool isAllowedDebugVarType(IRInst* type)
{
switch (type->getOp())
{
case kIROp_UnsizedArrayType:
return false;
case kIROp_ArrayType:
- return isLegalType(as<IRArrayType>(type)->getElementType());
+ return isAllowedDebugVarType(as<IRArrayType>(type)->getElementType());
case kIROp_VectorType:
case kIROp_StructType:
case kIROp_MatrixType:
return true;
case kIROp_PtrType:
- return as<IRPtrTypeBase>(type)->getAddressSpace() == AddressSpace::UserPointer;
+ return as<IRPtrTypeBase>(type)->getAddressSpace() == AddressSpace::UserPointer ||
+ as<IRPtrTypeBase>(type)->getAddressSpace() == AddressSpace::GroupShared;
default:
if (as<IRBasicType>(type))
return true;
@@ -3899,7 +3905,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
builder.setInsertBefore(debugVar);
auto varType = tryGetPointedToType(&builder, debugVar->getDataType());
- if (!isLegalType(varType))
+ if (!isAllowedDebugVarType(varType))
return nullptr;
IRSizeAndAlignment sizeAlignment;
@@ -9473,10 +9479,16 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
if (auto ptrType = as<IRPtrTypeBase>(type))
{
- if (ptrType->getAddressSpace() == AddressSpace::StorageBuffer)
+ switch (ptrType->getAddressSpace())
{
+ case AddressSpace::StorageBuffer:
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_variable_pointers"));
+ requireSPIRVCapability(SpvCapabilityVariablePointersStorageBuffer);
+ break;
+ case AddressSpace::GroupShared:
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_variable_pointers"));
requireSPIRVCapability(SpvCapabilityVariablePointers);
+ break;
}
}
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index a46b20b5f..011d1fec9 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -202,7 +202,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
{
auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType());
- pairType = builder->getPtrType(kIROp_PtrType, (IRType*)loweredType);
+ pairType = builder->getPtrType((IRType*)loweredType);
}
else
{
diff --git a/source/slang/slang-ir-explicit-global-context.cpp b/source/slang/slang-ir-explicit-global-context.cpp
index 9fae4ce15..dd07db883 100644
--- a/source/slang/slang-ir-explicit-global-context.cpp
+++ b/source/slang/slang-ir-explicit-global-context.cpp
@@ -316,8 +316,7 @@ struct IntroduceExplicitGlobalContextPass
// The context will usually be passed around by pointer,
// so we get and cache that pointer type up front.
//
- m_contextStructPtrType =
- builder.getPtrType(kIROp_PtrType, m_contextStructType, getAddressSpaceOfLocal());
+ m_contextStructPtrType = builder.getPtrType(m_contextStructType, getAddressSpaceOfLocal());
// The first step will be to create fields in the `KernelContext`
@@ -630,7 +629,7 @@ struct IntroduceExplicitGlobalContextPass
auto ptrType = getGlobalVarPtrType(globalVar);
if (fieldInfo.needDereference)
- ptrType = builder.getPtrType(kIROp_PtrType, ptrType, getAddressSpaceOfLocal());
+ ptrType = builder.getPtrType(ptrType, getAddressSpaceOfLocal());
// We then iterate over the uses of the variable,
// being careful to defend against the use/def information
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 65b997195..203a610dd 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -1466,7 +1466,11 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
// Set the array size to 0, to mean it is unsized
auto arrayType = builder->getArrayType(type, 0);
- IRType* paramType = builder->getPtrType(ptrOpCode, arrayType, addrSpace);
+ auto accessQualifier = AccessQualifier::ReadWrite;
+ if (kind == LayoutResourceKind::VaryingInput)
+ accessQualifier = AccessQualifier::Read;
+ IRType* paramType =
+ builder->getPtrType(ptrOpCode, arrayType, accessQualifier, addrSpace);
auto globalParam = addGlobalParam(builder->getModule(), paramType);
moveValueBefore(globalParam, builder->getFunc());
@@ -2558,7 +2562,7 @@ static void consolidateParameters(GLSLLegalizationContext* context, List<IRParam
// Create a global variable to hold the consolidated struct
consolidatedVar = builder->createGlobalVar(structType);
- auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload);
+ auto ptrType = builder->getPtrType(structType, AddressSpace::IncomingRayPayload);
consolidatedVar->setFullType(ptrType);
consolidatedVar->moveToEnd();
@@ -3088,7 +3092,8 @@ IRInst* getOrCreatePerVertexInputArray(GLSLLegalizationContext* context, IRInst*
auto arrayType = builder.getArrayType(
tryGetPointedToType(&builder, inputVertexAttr->getDataType()),
builder.getIntValue(builder.getIntType(), 3));
- arrayInst = builder.createGlobalParam(builder.getPtrType(arrayType, AddressSpace::Input));
+ arrayInst = builder.createGlobalParam(
+ builder.getPtrType(arrayType, AccessQualifier::Read, AddressSpace::Input));
context->mapVertexInputToPerVertexArray[inputVertexAttr] = arrayInst;
builder.addDecoration(arrayInst, kIROp_PerVertexDecoration);
@@ -4301,10 +4306,7 @@ void legalizeEntryPointForGLSL(
// Re-add ptr if there was one on the input
if (ptrType)
{
- sizedArrayType = builder.getPtrType(
- ptrType->getOp(),
- sizedArrayType,
- ptrType->getAddressSpace());
+ sizedArrayType = builder.getPtrType(sizedArrayType, ptrType);
}
// Change the globals type
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index c8604f4fa..9208e546c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3721,6 +3721,7 @@ public:
IRGenericKind* getGenericKind();
IRPtrType* getPtrType(IRType* valueType);
+ IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
// Form a ptr type to `valueType` using the same opcode and address space as `ptrWithAddrSpace`.
IRPtrTypeBase* getPtrTypeWithAddressSpace(IRType* valueType, IRPtrTypeBase* ptrWithAddrSpace);
@@ -3728,14 +3729,41 @@ public:
IROutType* getOutType(IRType* valueType);
IRInOutType* getInOutType(IRType* valueType);
IRRefType* getRefType(IRType* valueType, AddressSpace addrSpace);
- IRConstRefType* getConstRefType(IRType* valueType);
IRConstRefType* getConstRefType(IRType* valueType, AddressSpace addrSpace);
- IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
- IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace);
- IRPtrType* getPtrType(IROp op, IRType* valueType, IRInst* addressSpace);
+ IRPtrType* getPtrType(
+ IROp op,
+ IRType* valueType,
+ AccessQualifier accessQualifier,
+ AddressSpace addressSpace);
+ IRPtrType* getPtrType(
+ IROp op,
+ IRType* valueType,
+ IRInst* accessQualifier,
+ IRInst* addressSpace);
+ IRPtrType* getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace)
+ {
+ return getPtrType(op, valueType, AccessQualifier::ReadWrite, addressSpace);
+ }
+ IRPtrType* getPtrType(
+ IRType* valueType,
+ AccessQualifier accessQualifier,
+ AddressSpace addressSpace)
+ {
+ return getPtrType(kIROp_PtrType, valueType, accessQualifier, addressSpace);
+ }
IRPtrType* getPtrType(IRType* valueType, AddressSpace addressSpace)
{
- return getPtrType(kIROp_PtrType, valueType, addressSpace);
+ return getPtrType(valueType, AccessQualifier::ReadWrite, addressSpace);
+ }
+ // Copies the op-type of the oldPtrType, access-qualifier and address-space.
+ // Does not reuse the same `inst` for access-qualifier and address-space.
+ IRPtrTypeBase* getPtrType(IRType* valueType, IRPtrTypeBase* oldPtrType)
+ {
+ return getPtrType(
+ oldPtrType->getOp(),
+ valueType,
+ oldPtrType->getAccessQualifier(),
+ oldPtrType->getAddressSpace());
}
IRTextureTypeBase* getTextureType(
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 9cf8d7b5f..9363bc882 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -3699,6 +3699,7 @@ static LegalVal legalizeGlobalVar(IRTypeLegalizationContext* context, IRGlobalVa
irGlobalVar,
context->builder->getPtrType(
legalValueType.getSimple(),
+ varPtrType ? varPtrType->getAccessQualifier() : AccessQualifier::ReadWrite,
varPtrType ? varPtrType->getAddressSpace() : AddressSpace::Global));
return LegalVal::simple(irGlobalVar);
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index 621c7a55e..062330836 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -3597,10 +3597,8 @@ protected:
IRPtrTypeBase* type = as<IRPtrTypeBase>(param->getDataType());
- const auto annotatedPayloadType = builder.getPtrType(
- kIROp_ConstRefType,
- type->getValueType(),
- AddressSpace::MetalObjectData);
+ const auto annotatedPayloadType =
+ builder.getConstRefType(type->getValueType(), AddressSpace::MetalObjectData);
param->setFullType(annotatedPayloadType);
}
diff --git a/source/slang/slang-ir-specialize-address-space.cpp b/source/slang/slang-ir-specialize-address-space.cpp
index 2bc1de775..29f1ec516 100644
--- a/source/slang/slang-ir-specialize-address-space.cpp
+++ b/source/slang/slang-ir-specialize-address-space.cpp
@@ -103,8 +103,11 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext
if (ptrType)
{
auto paramAddrSpace = key.getArgAddrSpaces()[paramIndex];
- auto newParamType =
- builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), paramAddrSpace);
+ auto newParamType = builder.getPtrType(
+ ptrType->getOp(),
+ ptrType->getValueType(),
+ ptrType->getAccessQualifier(),
+ paramAddrSpace);
param->setFullType(newParamType);
mapInstToAddrSpace[param] = paramAddrSpace;
}
@@ -310,6 +313,7 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext
auto newResultType = builder.getPtrType(
ptrResultType->getOp(),
ptrResultType->getValueType(),
+ ptrResultType->getAccessQualifier(),
addrSpace);
fixUpFuncType(func, newResultType);
retValAddrSpaceChanged = true;
@@ -349,8 +353,11 @@ struct AddressSpaceContext : public AddressSpaceSpecializationContext
if (ptrType->getAddressSpace() != addrSpace)
{
IRBuilder builder(inst);
- auto newType =
- builder.getPtrType(ptrType->getOp(), ptrType->getValueType(), addrSpace);
+ auto newType = builder.getPtrType(
+ ptrType->getOp(),
+ ptrType->getValueType(),
+ ptrType->getAccessQualifier(),
+ addrSpace);
setDataType(inst, newType);
}
}
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp
index c03e644de..7c82891a6 100644
--- a/source/slang/slang-ir-specialize-function-call.cpp
+++ b/source/slang/slang-ir-specialize-function-call.cpp
@@ -662,12 +662,12 @@ struct FunctionParameterSpecializationContext
case kIROp_OutType:
case kIROp_RefType:
case kIROp_ConstRefType:
- argType = as<IRPtrTypeBase>(argType)->getValueType();
- resultType = getBuilder()->getPtrType(
- paramType->getOp(),
- argType,
- as<IRPtrTypeBase>(paramType)->getAddressSpace());
- break;
+ {
+ auto ptrParamType = as<IRPtrTypeBase>(paramType);
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ resultType = getBuilder()->getPtrType(argType, ptrParamType);
+ break;
+ }
}
if (auto rate = paramType->getRate())
{
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 9433b560b..2c4bd11cc 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -249,7 +249,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
{
builder.setInsertBefore(use->getUser());
auto addr = builder.emitFieldAddress(
- builder.getPtrType(kIROp_PtrType, innerType, AddressSpace::Uniform),
+ builder.getPtrType(innerType, AccessQualifier::Read, AddressSpace::Uniform),
cbParamInst,
key);
use->set(addr);
@@ -291,12 +291,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto basePtrType = as<IRPtrTypeBase>(addr->getDataType());
IRType* ptrType = nullptr;
if (basePtrType->hasAddressSpace())
- ptrType = builder.getPtrType(
- kIROp_PtrType,
- user->getDataType(),
- basePtrType->getAddressSpace());
+ ptrType = builder.getPtrType(user->getDataType(), basePtrType);
else
- ptrType = builder.getPtrType(kIROp_PtrType, user->getDataType());
+ ptrType = builder.getPtrType(user->getDataType());
IRInst* subAddr = nullptr;
if (user->getOp() == kIROp_GetElement)
subAddr = builder.emitElementAddress(
@@ -443,6 +440,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
return;
}
+ AccessQualifier access = AccessQualifier::ReadWrite;
// Opaque resource handles can't be in Uniform for Vulkan, if they are
// placed here then put them in UniformConstant instead
if (isSpirvUniformConstantType(inst->getDataType()))
@@ -518,7 +516,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// structured buffers in GLSL should be annotated as ReadOnly
if (as<IRHLSLStructuredBufferType>(structuredBufferType))
+ {
+ access = AccessQualifier::Read;
memoryFlags = MemoryQualifierSetModifier::Flags::kReadOnly;
+ }
if (as<IRHLSLRasterizerOrderedStructuredBufferType>(structuredBufferType))
memoryFlags = MemoryQualifierSetModifier::Flags::kRasterizerOrdered;
@@ -555,7 +556,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Make a pointer type of storageClass.
builder.setInsertBefore(inst);
- ptrType = builder.getPtrType(kIROp_PtrType, innerType, addressSpace);
+ ptrType = builder.getPtrType(innerType, access, addressSpace);
inst->setFullType(ptrType);
if (needLoad)
{
@@ -578,7 +579,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(getElement);
builder.setInsertBefore(user);
auto newAddr = builder.emitElementAddress(
- builder.getPtrType(kIROp_PtrType, innerElementType, addressSpace),
+ builder.getPtrType(
+ innerElementType,
+ ptrType->getAccessQualifier(),
+ addressSpace),
inst,
getElement->getIndex());
user->replaceUsesWith(newAddr);
@@ -714,6 +718,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newPtrType = builder.getPtrType(
oldPtrType->getOp(),
oldPtrType->getValueType(),
+ oldPtrType->getAccessQualifier(),
AddressSpace::Function);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
@@ -735,9 +740,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (block == func->getFirstBlock())
{
- // A pointer typed function parameter should always be in the storage buffer address
- // space.
- addressSpace = AddressSpace::UserPointer;
+ // A pointer typed function parameter is in the storage buffer address
+ // space or groupshared.
+ if (as<IRGroupSharedRate>(inst->getRate()))
+ addressSpace = AddressSpace::GroupShared;
+ else
+ addressSpace = AddressSpace::UserPointer;
}
else
{
@@ -765,6 +773,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newPtrType = builder.getPtrType(
oldPtrType->getOp(),
oldPtrType->getValueType(),
+ oldPtrType->getAccessQualifier(),
AddressSpace::UserPointer);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
@@ -785,10 +794,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
IRBuilder builder(inst);
builder.setInsertBefore(inst);
IRType* newPtrType = oldPtrType->hasAddressSpace()
- ? builder.getPtrType(
- oldPtrType->getOp(),
- newPtrValueType,
- oldPtrType->getAddressSpace())
+ ? builder.getPtrType(newPtrValueType, oldPtrType)
: builder.getPtrType(oldPtrType->getOp(), newPtrValueType);
inst->setFullType(newPtrType);
}
@@ -839,8 +845,11 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
- auto newPtrType =
- builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), addressSpace);
+ auto newPtrType = builder.getPtrType(
+ oldPtrType->getOp(),
+ oldPtrType->getValueType(),
+ oldPtrType->getAccessQualifier(),
+ addressSpace);
inst->setFullType(newPtrType);
addUsersToWorkList(inst);
return;
@@ -1022,6 +1031,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newPtrType = builder.getPtrType(
oldResultType->getOp(),
oldResultType->getValueType(),
+ ptrType->getAccessQualifier(),
ptrType->getAddressSpace());
IRInst* args[2] = {base, index};
auto newInst = builder.emitIntrinsicInst(newPtrType, gepInst->getOp(), 2, args);
@@ -1075,6 +1085,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newResultType = builder.getPtrType(
resultPtrType->getOp(),
resultPtrType->getValueType(),
+ ptrOperandType->getAccessQualifier(),
ptrOperandType->getAddressSpace());
auto newInst = builder.replaceOperand(&offsetPtrInst->typeUse, newResultType);
addUsersToWorkList(newInst);
@@ -1095,8 +1106,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(loadInst);
IRInst* args[] = {sb, index};
auto addrInst = builder.emitIntrinsicInst(
- builder
- .getPtrType(kIROp_PtrType, loadInst->getFullType(), getStorageBufferAddressSpace()),
+ builder.getPtrType(loadInst->getFullType(), getStorageBufferAddressSpace()),
kIROp_RWStructuredBufferGetElementPtr,
2,
args);
@@ -1115,7 +1125,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.setInsertBefore(storeInst);
IRInst* args[] = {sb, index};
auto addrInst = builder.emitIntrinsicInst(
- builder.getPtrType(kIROp_PtrType, value->getFullType(), getStorageBufferAddressSpace()),
+ builder.getPtrType(value->getFullType(), getStorageBufferAddressSpace()),
kIROp_RWStructuredBufferGetElementPtr,
2,
args);
@@ -1168,6 +1178,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
auto newPtrType = builder.getPtrType(
oldResultType->getOp(),
newValueType,
+ ptrType->getAccessQualifier(),
ptrType->getAddressSpace());
auto newInst =
builder.emitFieldAddress(newPtrType, inst->getBase(), inst->getField());
@@ -2172,8 +2183,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
// Update the global param's type to use the wrapper struct
- auto newPtrType =
- builder.getPtrType(ptrType->getOp(), wrapperStruct, ptrType->getAddressSpace());
+ auto newPtrType = builder.getPtrType(wrapperStruct, ptrType);
globalParam->setFullType(newPtrType);
// Traverse all uses of the global param and insert a FieldAddress to access the
@@ -2184,7 +2194,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
{
builder.setInsertBefore(use->getUser());
auto addr = builder.emitFieldAddress(
- builder.getPtrType(kIROp_PtrType, structType, ptrType->getAddressSpace()),
+ builder.getPtrType(structType, ptrType),
globalParam,
key);
use->set(addr);
@@ -2246,11 +2256,16 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
for (auto t : instsToProcess)
{
auto lowered = lowerStructuredBufferType(t);
+
+ AccessQualifier accessQualifier = AccessQualifier::ReadWrite;
+ if (as<IRHLSLStructuredBufferType>(t))
+ accessQualifier = AccessQualifier::Read;
+
IRBuilder builder(t);
builder.setInsertBefore(t);
t->replaceUsesWith(builder.getPtrType(
- kIROp_PtrType,
lowered.structType,
+ accessQualifier,
getStorageBufferAddressSpace()));
}
for (auto t : textureFootprintTypes)
diff --git a/source/slang/slang-ir-translate-global-varying-var.cpp b/source/slang/slang-ir-translate-global-varying-var.cpp
index 57b277d25..c899de653 100644
--- a/source/slang/slang-ir-translate-global-varying-var.cpp
+++ b/source/slang/slang-ir-translate-global-varying-var.cpp
@@ -220,8 +220,8 @@ struct GlobalVarTranslationContext
input->transferDecorationsTo(key);
// Emit a new param here to represent the global input var.
- auto inputParam = builder.emitParam(
- builder.getPtrType(kIROp_ConstRefType, inputType, AddressSpace::Input));
+ auto inputParam =
+ builder.emitParam(builder.getConstRefType(inputType, AddressSpace::Input));
// Copy the global input vars original decorations onto the new param.
// We need to do this to ensure that we can do things like get system
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6b9273c15..ab59112f3 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2935,17 +2935,13 @@ IRInOutType* IRBuilder::getInOutType(IRType* valueType)
IRRefType* IRBuilder::getRefType(IRType* valueType, AddressSpace addrSpace)
{
- return (IRRefType*)getPtrType(kIROp_RefType, valueType, addrSpace);
-}
-
-IRConstRefType* IRBuilder::getConstRefType(IRType* valueType)
-{
- return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType);
+ return (IRRefType*)getPtrType(kIROp_RefType, valueType, AccessQualifier::ReadWrite, addrSpace);
}
IRConstRefType* IRBuilder::getConstRefType(IRType* valueType, AddressSpace addrSpace)
{
- return (IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType, addrSpace);
+ return (
+ IRConstRefType*)getPtrType(kIROp_ConstRefType, valueType, AccessQualifier::Read, addrSpace);
}
IRSPIRVLiteralType* IRBuilder::getSPIRVLiteralType(IRType* type)
@@ -2965,23 +2961,35 @@ IRPtrTypeBase* IRBuilder::getPtrTypeWithAddressSpace(
IRPtrTypeBase* ptrWithAddrSpace)
{
if (ptrWithAddrSpace->hasAddressSpace())
- return (IRPtrTypeBase*)
- getPtrType(ptrWithAddrSpace->getOp(), valueType, ptrWithAddrSpace->getAddressSpace());
+ return (IRPtrTypeBase*)getPtrType(
+ ptrWithAddrSpace->getOp(),
+ valueType,
+ ptrWithAddrSpace->getAccessQualifier(),
+ ptrWithAddrSpace->getAddressSpace());
return (IRPtrTypeBase*)getPtrType(ptrWithAddrSpace->getOp(), valueType);
}
-IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, AddressSpace addressSpace)
+IRPtrType* IRBuilder::getPtrType(
+ IROp op,
+ IRType* valueType,
+ AccessQualifier accessQualifier,
+ AddressSpace addressSpace)
{
return (IRPtrType*)getPtrType(
op,
valueType,
+ getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(accessQualifier)),
getIntValue(getUInt64Type(), static_cast<IRIntegerValue>(addressSpace)));
}
-IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRInst* addressSpace)
+IRPtrType* IRBuilder::getPtrType(
+ IROp op,
+ IRType* valueType,
+ IRInst* accessQualifier,
+ IRInst* addressSpace)
{
- IRInst* operands[] = {valueType, addressSpace};
- return (IRPtrType*)getType(op, addressSpace ? 2 : 1, operands);
+ IRInst* operands[] = {valueType, accessQualifier, addressSpace};
+ return (IRPtrType*)getType(op, addressSpace ? 3 : 1, operands);
}
IRTextureTypeBase* IRBuilder::getTextureType(
@@ -4822,7 +4830,7 @@ IRGlobalVar* IRBuilder::createGlobalVar(IRType* valueType)
IRGlobalVar* IRBuilder::createGlobalVar(IRType* valueType, AddressSpace addressSpace)
{
- auto ptrType = getPtrType(kIROp_PtrType, valueType, addressSpace);
+ auto ptrType = getPtrType(valueType, addressSpace);
IRGlobalVar* globalVar = createInst<IRGlobalVar>(this, kIROp_GlobalVar, ptrType);
_maybeSetSourceLoc(globalVar);
addGlobalValue(this, globalVar);
@@ -5079,7 +5087,7 @@ IRVar* IRBuilder::emitVar(IRType* type)
IRVar* IRBuilder::emitVar(IRType* type, AddressSpace addressSpace)
{
- auto allocatedType = getPtrType(kIROp_PtrType, type, addressSpace);
+ auto allocatedType = getPtrType(type, addressSpace);
auto inst = createInst<IRVar>(this, kIROp_Var, allocatedType);
addInst(inst);
return inst;
@@ -5308,6 +5316,7 @@ IRType* maybePropagateAddressSpace(IRBuilder* builder, IRInst* basePtr, IRType*
type = builder->getPtrType(
resultPtrType->getOp(),
resultPtrType->getValueType(),
+ basePtrType->getAccessQualifier(),
basePtrType->getAddressSpace());
}
}
@@ -5318,10 +5327,12 @@ IRType* maybePropagateAddressSpace(IRBuilder* builder, IRInst* basePtr, IRType*
IRInst* IRBuilder::emitFieldAddress(IRInst* basePtr, IRInst* fieldKey)
{
AddressSpace addrSpace = AddressSpace::Generic;
+ AccessQualifier accessQualifier = AccessQualifier::ReadWrite;
IRInst* valueType = nullptr;
auto basePtrType = unwrapAttributedType(basePtr->getDataType());
if (auto ptrType = as<IRPtrTypeBase>(basePtrType))
{
+ accessQualifier = ptrType->getAccessQualifier();
addrSpace = ptrType->getAddressSpace();
valueType = ptrType->getValueType();
}
@@ -5344,7 +5355,7 @@ IRInst* IRBuilder::emitFieldAddress(IRInst* basePtr, IRInst* fieldKey)
}
}
SLANG_RELEASE_ASSERT(resultType);
- return emitFieldAddress(getPtrType(kIROp_PtrType, resultType, addrSpace), basePtr, fieldKey);
+ return emitFieldAddress(getPtrType(resultType, accessQualifier, addrSpace), basePtr, fieldKey);
}
IRInst* IRBuilder::emitFieldAddress(IRType* type, IRInst* base, IRInst* field)
@@ -5448,10 +5459,12 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRIntegerValue index)
IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index)
{
AddressSpace addrSpace = AddressSpace::Generic;
+ AccessQualifier accessQualifier = AccessQualifier::ReadWrite;
IRInst* valueType = nullptr;
auto basePtrType = unwrapAttributedType(basePtr->getDataType());
if (auto ptrType = as<IRPtrTypeBase>(basePtrType))
{
+ accessQualifier = ptrType->getAccessQualifier();
addrSpace = ptrType->getAddressSpace();
valueType = ptrType->getValueType();
}
@@ -5500,7 +5513,7 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index)
auto inst = createInst<IRGetElementPtr>(
this,
kIROp_GetElementPtr,
- getPtrType(kIROp_PtrType, type, addrSpace),
+ getPtrType(type, accessQualifier, addrSpace),
basePtr,
index);
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index ad14edf21..d8fe51ddf 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1681,15 +1681,22 @@ struct IRPtrTypeBase : IRType
FIDDLE(baseInst())
IRType* getValueType() { return (IRType*)getOperand(0); }
+ AccessQualifier getAccessQualifier()
+ {
+ return getOperandCount() > 1
+ ? (AccessQualifier) static_cast<IRIntLit*>(getOperand(1))->getValue()
+ : AccessQualifier::ReadWrite;
+ }
+
bool hasAddressSpace()
{
- return getOperandCount() > 1 && getAddressSpace() != AddressSpace::Generic;
+ return getOperandCount() > 2 && getAddressSpace() != AddressSpace::Generic;
}
AddressSpace getAddressSpace()
{
- return getOperandCount() > 1
- ? (AddressSpace) static_cast<IRIntLit*>(getOperand(1))->getValue()
+ return getOperandCount() > 2
+ ? (AddressSpace) static_cast<IRIntLit*>(getOperand(2))->getValue()
: AddressSpace::Generic;
}
};
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 8e1f85f8e..4df778ee6 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2067,7 +2067,14 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto astValueType = type->getValueType();
IRType* irValueType = lowerType(context, astValueType);
+ IRInst* accessQualifier = nullptr;
IRInst* addrSpace = nullptr;
+
+ if (auto astAccessQualifier = type->getAccessQualifier())
+ {
+ accessQualifier = getSimpleVal(context, lowerVal(context, astAccessQualifier));
+ }
+
if (auto astAddrSpace = type->getAddressSpace())
{
addrSpace = getSimpleVal(context, lowerVal(context, astAddrSpace));
@@ -2078,7 +2085,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
getBuilder()->getUInt64Type(),
(IRIntegerValue)AddressSpace::Generic);
}
- return getBuilder()->getPtrType(kIROp_PtrType, irValueType, addrSpace);
+
+ return getBuilder()->getPtrType(kIROp_PtrType, irValueType, accessQualifier, addrSpace);
}
IRType* visitDeclRefType(DeclRefType* type)
@@ -3437,7 +3445,6 @@ void _lowerFuncDeclBaseTypeInfo(
auto& parameterLists = outInfo.parameterLists;
collectParameterLists(
context,
-
declRef,
&parameterLists,
kParameterListCollectMode_Default,
@@ -3469,7 +3476,7 @@ void _lowerFuncDeclBaseTypeInfo(
irParamType = builder->getRefType(irParamType, AddressSpace::Generic);
break;
case kParameterDirection_ConstRef:
- irParamType = builder->getConstRefType(irParamType);
+ irParamType = builder->getConstRefType(irParamType, AddressSpace::Generic);
break;
default:
SLANG_UNEXPECTED("unknown parameter direction");
@@ -4157,6 +4164,39 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
ASTBuilder* getASTBuilder() { return context->astBuilder; }
LoweredValInfo lowerSubExpr(Expr* expr) { return sharedLoweringContext.lowerSubExpr(expr); }
+ LoweredValInfo visitAddressOfExpr(AddressOfExpr* expr)
+ {
+ auto loweredType = lowerType(context, expr->type);
+ auto baseVal = lowerLValueExpr(context, expr->arg);
+ auto ptr = tryGetAddress(context, baseVal, TryGetAddressMode::Aggressive);
+
+ switch (ptr.flavor)
+ {
+ case LoweredValInfo::Flavor::Ptr:
+ {
+ // TODO: This is a hack. We should just be returning `ptr`. We do not do this since
+ // `ptr` may have the wrong address space. This happens since when lowering-to-ir we
+ // don't check what addres-space info we should be using for variables we create.
+ // example: `groupshared int ptr` ==> lower-to-ir lowers as default address-space
+ // with groupshared-rate.
+ //
+ // We need to emit a temporary variable (and cannot emit a cast) since `operator*`
+ // has its own hacks and is an incorrect implementation of its own. To elaborate,
+ // `operator*` is defined as `__intrinsic_op(0)`, which means "pass arguments
+ // through a function `in`, then set as result". This is an issue since this means
+ // that our function (which should be returning a `ref`) may in fact, not be
+ // returning a `ref` but instead be loading via the `in` parameter and generating a
+ // non-pointer result.
+ auto irVar = context->irBuilder->emitVar(loweredType);
+ context->irBuilder->emitStore(irVar, ptr.val);
+ return LoweredValInfo::ptr(irVar);
+ }
+ default:
+ SLANG_UNIMPLEMENTED_X("cannot get address of __getAddress(...) argument");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+ }
+
LoweredValInfo visitIncompleteExpr(IncompleteExpr*)
{
SLANG_UNEXPECTED("a valid ast should not contain an IncompleteExpr.");
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index ea620ebb2..7a5665905 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -191,6 +191,16 @@ void emitBaseType(ManglingContext* context, BaseType baseType)
case BaseType::IntPtr:
emitRaw(context, "ip");
break;
+ case BaseType::AddressSpace:
+ emitRaw(context, "as");
+ break;
+ case BaseType::AccessQualifier:
+ emitRaw(context, "aq");
+ break;
+ case BaseType::MemoryScope:
+ emitRaw(context, "mem");
+ break;
+
default:
SLANG_UNEXPECTED("unimplemented case in base type mangling");
break;
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 1302975df..e71b6162c 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -7026,7 +7026,6 @@ static NodeBase* parseSizeOfExpr(Parser* parser, void* /*userData*/)
static NodeBase* parseAlignOfExpr(Parser* parser, void* /*userData*/)
{
- // We could have a type or a variable or an expression
AlignOfExpr* alignOfExpr = parser->astBuilder->create<AlignOfExpr>();
parser->ReadMatchingToken(TokenType::LParent);
@@ -7058,6 +7057,17 @@ static NodeBase* parseCountOfExpr(Parser* parser, void* /*userData*/)
return countOfExpr;
}
+static NodeBase* parseAddressOfExpr(Parser* parser, void* /*userData*/)
+{
+ // We could have a type or a variable or an expression
+ AddressOfExpr* addressOfExpr = parser->astBuilder->create<AddressOfExpr>();
+
+ parser->ReadMatchingToken(TokenType::LParent);
+ addressOfExpr->arg = parser->ParseExpression();
+ parser->ReadMatchingToken(TokenType::RParent);
+ return addressOfExpr;
+}
+
static NodeBase* parseTryExpr(Parser* parser, void* /*userData*/)
{
auto tryExpr = parser->astBuilder->create<TryExpr>();
@@ -9648,6 +9658,7 @@ static const SyntaxParseInfo g_parseSyntaxEntries[] = {
_makeParseExpr("sizeof", parseSizeOfExpr),
_makeParseExpr("alignof", parseAlignOfExpr),
_makeParseExpr("countof", parseCountOfExpr),
+ _makeParseExpr("__getAddress", parseAddressOfExpr),
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 67d562f0f..d9eb884f0 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -927,7 +927,7 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR
}
if (paramDecl->findModifier<RefModifier>())
{
- paramType = astBuilder->getRefType(paramType, AddressSpace::Generic);
+ paramType = astBuilder->getRefType(paramType);
}
else if (paramDecl->findModifier<ConstRefModifier>())
{
diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h
index d7bd43122..3390c3b80 100644
--- a/source/slang/slang-type-system-shared.h
+++ b/source/slang/slang-type-system-shared.h
@@ -22,6 +22,10 @@ namespace Slang
X(Char) \
X(IntPtr) \
X(UIntPtr) \
+ X(CountOfPrimitives) \
+ X(AddressSpace) \
+ X(MemoryScope) \
+ X(AccessQualifier) \
/* end */
enum class BaseType
@@ -114,6 +118,26 @@ enum class AddressSpace : uint64_t
// Default address space for a user-defined pointer
UserPointer = 0x100000001ULL,
};
+
+// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id
+// must be 32 bit to match SPIR-V
+enum class MemoryScope : int32_t
+{
+ CrossDevice = 0,
+ Device = 1,
+ Workgroup = 2,
+ Subgroup = 3,
+ Invocation = 4,
+ QueueFamily = 5,
+ ShaderCall = 6,
+};
+
+enum class AccessQualifier : uint64_t
+{
+ ReadWrite = 0,
+ Read = 1,
+};
+
} // namespace Slang
#endif
diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis
index edd34f456..97f4689f1 100644
--- a/source/slang/slang.natvis
+++ b/source/slang/slang.natvis
@@ -260,7 +260,8 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndTypeExpr">(Slang::AndTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedTypeExpr">(Slang::ModifiedTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PointerTypeExpr">(Slang::PointerTypeExpr*)&amp;astNodeType</ExpandedItem>
- <Item Name="[type]">type</Item>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AddressOfExpr">(Slang::AddressOfExpr*)&amp;astNodeType</ExpandedItem>
+ <Item Name="[type]">type</Item>
<Item Name="[Expr]">(Slang::Expr*)this,!</Item>
</Expand>
</Type>
@@ -484,11 +485,12 @@
</Expand>
</Type>
<Type Name="Slang::ValNodeOperand">
- <DisplayString Optional="true" Condition="kind==Slang::ValNodeOperandKind::ConstantValue">Const({values.intOperand})#{_debugUID}</DisplayString>
+ <DisplayString Condition="kind==Slang::ValNodeOperandKind::ConstantValue">ConstantValue ({this->values.intOperand}) #{((Val*)this)->_debugUID}</DisplayString>
<DisplayString Condition="kind==Slang::ValNodeOperandKind::ValNode">{*(Val*)values.nodeOperand}</DisplayString>
<DisplayString>{values.nodeOperand}</DisplayString>
<Expand>
- <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ValNode">*(Val*)values.nodeOperand</ExpandedItem>
+ <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ConstantValue">values</ExpandedItem>
+ <ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ValNode">*(Val*)values.nodeOperand</ExpandedItem>
<ExpandedItem Condition="kind==Slang::ValNodeOperandKind::ASTNode">*(Decl*)values.nodeOperand</ExpandedItem>
</Expand>
</Type>
diff --git a/tests/autodiff/get-offset-ptr.slang b/tests/autodiff/get-offset-ptr.slang
index 517acb54d..e497f1e48 100644
--- a/tests/autodiff/get-offset-ptr.slang
+++ b/tests/autodiff/get-offset-ptr.slang
@@ -1,40 +1,32 @@
-//TEST:SIMPLE(filecheck=CHECK): -target cuda -line-directive-mode none
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -cuda -output-using-type
-//CHECK: struct s_bwd_prop_function_Intermediates{{[_0-9]+}}
-//CHECK: {
-//CHECK: MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}};
-//CHECK: MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}};
-//CHECK: };
+// This test just ensures that we compile and run the code.
+// It does not check the correctness of the autodiff.
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out, name outputBuffer
RWStructuredBuffer<float> outputBuffer;
struct MyDiffPtr
{
- uint offset;
- uint d_offset;
-
- [BackwardDerivative(__bwd_foo)]
- float foo()
- {
- return outputBuffer[offset] * outputBuffer[offset];
- }
-
- void __bwd_foo(float grad)
- {
- outputBuffer[d_offset] = 2.f * outputBuffer[offset] * grad;
- }
+ float data1;
+ float data2;
};
[Differentiable]
-float function(MyDiffPtr *i)
+float function(Ptr<MyDiffPtr, Access::ReadWrite, AddressSpace::GroupShared> i)
{
- return i[0].foo() + i[1].foo();
+ return i[0].data1 + i[1].data2;
}
+groupshared MyDiffPtr s[2];
[numthreads(1, 1, 1), shader("compute")]
-void main(uint3 dispatchThreadID: SV_DispatchThreadID)
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
- MyDiffPtr s[2] = {{0, 2}, {1, 3}};
- __bwd_diff(function)(&s[0], 1.0f);
+ s = { { 0, 2 }, { 1, 3 } };
+ float result = 1.0f;
+ let pair = __fwd_diff(function)(__getAddress(s[0]));
+ outputBuffer[0] = pair.getPrimal();
+ outputBuffer[1] = pair.getDifferential();
+ // CHECK: 3.0
+ // CHECK-NEXT: 0.0
} \ No newline at end of file
diff --git a/tests/bugs/gh-3601.slang b/tests/bugs/gh-3601.slang
index 5d545262b..65245f971 100644
--- a/tests/bugs/gh-3601.slang
+++ b/tests/bugs/gh-3601.slang
@@ -4,7 +4,7 @@ struct TestStruct
uint index;
};
-[[vk::binding(2, 0)]] StructuredBuffer<uint64_t> test;
+[[vk::binding(2, 0)]] uniform uint64_t* test;
struct PP
{
@@ -28,15 +28,15 @@ int* funcThatReturnsPointer(PP* p)
// CHECK: OpEntryPoint
-[[vk::binding(0, 0)]] StructuredBuffer<Data> buffer;
+[[vk::binding(0, 0)]] uniform Data* buffer;
[[vk::binding(1, 0)]] RWStructuredBuffer<int> output;
[shader("compute")]
[numthreads(8, 8, 1)]
void main(int id : SV_DispatchThreadID)
{
- TestStruct * ptr = (TestStruct *)(test[0]);
+ TestStruct* ptr = (TestStruct*)(test[0]);
output[0] = buffer[ptr.index].pNext.data;
- let pData = &(buffer[0].pNext.data);
+ let pData = __getAddress(buffer[0].pNext.data);
// CHECK: OpPtrAccessChain
int* pData1 = pData + 1;
*pData1 = 3;
diff --git a/tests/diagnostics/invalid-constant-pointer-taking.slang b/tests/diagnostics/invalid-constant-pointer-taking.slang
index 349f8cc25..658a84b1b 100644
--- a/tests/diagnostics/invalid-constant-pointer-taking.slang
+++ b/tests/diagnostics/invalid-constant-pointer-taking.slang
@@ -1,4 +1,4 @@
-//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target spirv
+//TEST:SIMPLE(filecheck=CHECK): -stage compute -entry computeMain -target spirv
RWStructuredBuffer<float> mutable_float_buffer;
RWStructuredBuffer<uint> mutable_uint_buffer;
@@ -6,18 +6,24 @@ RWStructuredBuffer<uint> mutable_uint_buffer;
StructuredBuffer<float> constant_float_buffer;
StructuredBuffer<uint> constant_uint_buffer;
+// We do not allow taking a pointer from a StructuredBuffer/RWStructuredBuffer.
[shader("compute")]
[numthreads(1,1,1)]
void computeMain(uint3 threadId : SV_DispatchThreadID)
{
- float* mutablePtr = &mutable_float_buffer[threadId.x];
+ float* mutablePtr1 = &mutable_float_buffer[threadId.x];
+
+ // CHECK: ([[# @LINE+1]]): error 31160
+ float* mutablePtr2 = __getAddress(mutable_float_buffer[threadId.x]);
InterlockedAdd(mutable_uint_buffer[threadId.x], 1);
// Constant pointers arent a thing in slang
- // CHECK: error 30078:
- float* ptr = &constant_float_buffer[threadId.x];
+ // CHECK: ([[# @LINE+1]]): error 30079:
+ float* ptr1 = &constant_float_buffer[threadId.x];
+ // CHECK: ([[# @LINE+1]]): error 31160
+ float* ptr2 = __getAddress(constant_float_buffer[threadId.x]);
InterlockedAdd(constant_uint_buffer[0], 1);
-} \ No newline at end of file
+}
diff --git a/tests/language-feature/bitfield/msvc-repr-mixed.slang b/tests/language-feature/bitfield/msvc-repr-mixed.slang
index 47f03ad1d..cf1925dd6 100644
--- a/tests/language-feature/bitfield/msvc-repr-mixed.slang
+++ b/tests/language-feature/bitfield/msvc-repr-mixed.slang
@@ -19,20 +19,23 @@ struct MixedSizes {
uint16_t d : 8; // Same backing field
};
+groupshared MixedSizes m;
+
+typealias GroupSharedPtr<T> = Ptr<T, Access::ReadWrite, AddressSpace::GroupShared>;
+
[numthreads(1, 1, 1)]
void computeMain()
{
- MixedSizes m;
m.a = 0xA;
m.b = 0xB;
m.c = 0xCD;
m.d = 0xEF;
// Read the two backing fields separately
- uint8_t* p8 = (uint8_t*)&m;
- uint16_t* p16 = (uint16_t*)((uint8_t*)&m + 2); // Skip uint8_t + padding
+ GroupSharedPtr<uint8_t> p8 = (GroupSharedPtr<uint8_t>)__getAddress(m);
+ GroupSharedPtr<uint16_t> p16 = (GroupSharedPtr<uint16_t>)( ((GroupSharedPtr<uint8_t>)__getAddress(m)) + 2); // Skip uint8_t + padding
- outputBuffer[0] = uint(*p8);
- outputBuffer[1] = uint(*p16);
+ outputBuffer[0] = (uint)*p8;
+ outputBuffer[1] = (uint)*p16;
}
diff --git a/tests/language-feature/capability/address-of.slang b/tests/language-feature/capability/address-of.slang
new file mode 100644
index 000000000..924312b0e
--- /dev/null
+++ b/tests/language-feature/capability/address-of.slang
@@ -0,0 +1,17 @@
+//TEST:SIMPLE(filecheck=CHECK_FAIL): -target glsl -entry computeMain -stage compute
+//TEST:SIMPLE(filecheck=CHECK_PASS): -target spirv -entry computeMain -stage compute
+
+// Test that __getAddress correctly reports capabilities.
+
+uniform int* outputBuffer;
+uniform int* buffer;
+
+// CHECK_PASS: OpEntryPoint
+// CHECK_PASS-NOT: error
+
+// CHECK_FAIL: ([[# @LINE+1]]): error 36107{{.*}}glsl
+void computeMain()
+{
+ // CHECK: ([[# @LINE+1]]): note: see using of '__getAddress'
+ outputBuffer[0] = *(__getAddress(buffer[0]));
+} \ No newline at end of file
diff --git a/tests/language-feature/pointer/const-ptr-variations.slang b/tests/language-feature/pointer/const-ptr-variations.slang
new file mode 100644
index 000000000..a2619d6c4
--- /dev/null
+++ b/tests/language-feature/pointer/const-ptr-variations.slang
@@ -0,0 +1,40 @@
+//TEST:SIMPLE(filecheck=CHECK_1):-stage compute -entry computeMain -target spirv -DT1
+//TEST:SIMPLE(filecheck=CHECK_2):-stage compute -entry computeMain -target spirv -DT2
+//TEST:SIMPLE(filecheck=CHECK_3):-stage compute -entry computeMain -target spirv -DT3
+//TEST:SIMPLE(filecheck=CHECK_4):-stage compute -entry computeMain -target spirv -DT4
+
+// Tests for invalid use of `const` with Ptr/T*
+// Due to bad syntax breaking the parser, it is more robust to use disjoint tests with
+// #define's.
+cbuffer Globals
+{
+ int* ptr;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int id : SV_DispatchThreadID)
+{
+ // disallowed syntax with modifier `const`
+#ifdef T1
+ // CHECK_1: ([[# @LINE+1]]): error
+ int const* ptr1 = ptr;
+#endif
+
+#ifdef T2
+ // CHECK_2: ([[# @LINE+1]]): error
+ int* const ptr2 = ptr;
+#endif
+
+#ifdef T3
+ // CHECK_3: ([[# @LINE+1]]): error 20017
+ const int* ptr3 = ptr;
+ // CHECK_3: ([[# @LINE+1]]): error 20018
+ Ptr<const int> ptr4 = ptr;
+#endif
+
+#ifdef T4
+ // CHECK_4: OpEntryPoint
+ // CHECK_4-NOT: error
+ const Ptr<int> ptr5 = ptr;
+#endif
+} \ No newline at end of file
diff --git a/tests/language-feature/pointer/get-address-validation.slang b/tests/language-feature/pointer/get-address-validation.slang
new file mode 100644
index 000000000..3931c13a2
--- /dev/null
+++ b/tests/language-feature/pointer/get-address-validation.slang
@@ -0,0 +1,82 @@
+//TEST:SIMPLE(filecheck=CHECK):-stage compute -entry computeMain -target spirv
+
+// Tests for invalid/valid use of `__getAddress`
+
+struct DeviceStruct
+{
+ int data1;
+ int data2;
+}
+
+struct StructPtrInStruct
+{
+ DeviceStruct* ptr;
+}
+
+uniform int* bufferUserPointer;
+RWStructuredBuffer<int> bufferStorage;
+groupshared int bufferGroupShared[100];
+uniform DeviceStruct* bufferUserPointerStruct;
+uniform int2* bufferUserPointerVector;
+
+int* output;
+
+typealias GroupSharedPtr<T> = Ptr<T, Access::ReadWrite, AddressSpace::GroupShared>;
+
+GroupSharedPtr<T> paramGroupShared<T : __BuiltinIntegerType>(out groupshared T[100] ptr)
+{
+ // CHECK: ([[# @LINE+1]]): error 30019
+ T* ptr1 = __getAddress(ptr[5]);
+
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ GroupSharedPtr<T> ptr2 = __getAddress(ptr[5]);
+
+ return ptr2;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int id : SV_DispatchThreadID)
+{
+ // CHECK: ([[# @LINE+1]]): error 31160
+ int* ptr1 = __getAddress(bufferStorage[id.x]);
+
+ // CHECK ([[# @LINE+1]]): error
+ int[100]* ptr2 = __getAddress(bufferGroupShared);
+
+ // CHECK: ([[# @LINE+1]]): error
+ int* ptr3 = __getAddress(bufferGroupShared[id.x]);
+
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ int* ptr4 = __getAddress(bufferUserPointer[id.x]);
+
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ GroupSharedPtr<int[100]> ptr5 = __getAddress(bufferGroupShared);
+
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ GroupSharedPtr<int> ptr6 = __getAddress(bufferGroupShared[id.x]);
+
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ GroupSharedPtr<int> ptr7 = paramGroupShared(bufferGroupShared);
+
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ int* ptr8 = __getAddress(bufferUserPointerStruct.data1);
+
+ StructPtrInStruct structPtrInStruct;
+ structPtrInStruct.ptr = bufferUserPointerStruct;
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ int* ptr9 = __getAddress(structPtrInStruct.ptr[id.x].data1);
+
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ int* ptr10 = __getAddress(bufferUserPointerVector[0].x);
+
+ output[id] = ptr1[id];
+ output[id] = ptr2[id][0];
+ output[id] = ptr3[id];
+ output[id] = ptr4[id];
+ output[id] = ptr5[id];
+ output[id] = ptr6[id];
+ output[id] = ptr7[id];
+ output[id] = ptr8[id];
+ output[id] = ptr9[id];
+ output[id] = ptr10[id];
+}
diff --git a/tests/language-feature/pointer/globallycoherent-ptr.slang b/tests/language-feature/pointer/globallycoherent-ptr.slang
new file mode 100644
index 000000000..4909537d7
--- /dev/null
+++ b/tests/language-feature/pointer/globallycoherent-ptr.slang
@@ -0,0 +1,20 @@
+//TEST:SIMPLE(filecheck=CHECK):-stage compute -entry computeMain -target spirv
+
+// Tests for invalid use of `globallycoherent` with Ptr/T*
+
+cbuffer Globals
+{
+ // CHECK: ([[# @LINE+1]]): error 30078
+ globallycoherent Ptr<int> ptr1;
+ // CHECK: ([[# @LINE+1]]): error 30078
+ globallycoherent int* ptr2;
+ // CHECK: ([[# @LINE+1]]): error 30078
+ coherent Ptr<int> ptr3;
+ // CHECK: ([[# @LINE+1]]): error 30078
+ coherent int* ptr4;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int id : SV_DispatchThreadID)
+{
+}
diff --git a/tests/language-feature/pointer/groupshared-ptr-of-device.slang b/tests/language-feature/pointer/groupshared-ptr-of-device.slang
new file mode 100644
index 000000000..31703819e
--- /dev/null
+++ b/tests/language-feature/pointer/groupshared-ptr-of-device.slang
@@ -0,0 +1,28 @@
+//TEST:SIMPLE(filecheck=SPIRV):-stage compute -entry computeMain -target spirv -capability vk_mem_model+sm_6_0+spvGroupNonUniformBallot
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -capability vk_mem_model+sm_6_0+spvGroupNonUniformBallot
+
+// Tests if we pass-through and handle pointers via groupshared-memory correctly.
+// Ensure SPIRV emits coherent operations here
+// SPIRV: OpEntryPoint
+// SPIRV-NOT: error
+
+// CHECK: 1
+// CHECK-NEXT: 0
+// CHECK-NEXT: 2
+// CHECK-NEXT: 0
+// CHECK-NEXT: 3
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+uniform int* outputBuffer;
+
+groupshared int* sharedPtr[3];
+
+[numthreads(3, 1, 1)]
+void computeMain(uint3 group_thread_id: SV_GroupThreadID)
+{
+ sharedPtr[group_thread_id.x] = outputBuffer + group_thread_id.x;
+ sharedPtr[group_thread_id.x] = sharedPtr[group_thread_id.x]+group_thread_id.x;
+ GroupMemoryBarrierWithGroupSync();
+
+ *sharedPtr[group_thread_id.x] = group_thread_id.x+1;
+} \ No newline at end of file
diff --git a/tests/language-feature/pointer/pointer-access/pointer-access-frontend.slang b/tests/language-feature/pointer/pointer-access/pointer-access-frontend.slang
new file mode 100644
index 000000000..98a2a076d
--- /dev/null
+++ b/tests/language-feature/pointer/pointer-access/pointer-access-frontend.slang
@@ -0,0 +1,14 @@
+//TEST:SIMPLE(filecheck=CHECK):-stage compute -entry computeMain -target spirv
+
+//CHECK: OpEntryPoint
+//CHECK-NOT: error
+
+int* processMemory;
+int* output;
+
+[numthreads(1, 1, 1)]
+void computeMain(int id : SV_DispatchThreadID)
+{
+ Ptr<int, Access::ReadWrite, AddressSpace::Device> ptr1 = processMemory + id.x + 5;
+ Ptr<int> ptr2 = processMemory + id.x + 4;
+} \ No newline at end of file
diff --git a/tests/language-feature/pointer/pointer-access/read-only-pointer-1.slang b/tests/language-feature/pointer/pointer-access/read-only-pointer-1.slang
new file mode 100644
index 000000000..e7f1ad534
--- /dev/null
+++ b/tests/language-feature/pointer/pointer-access/read-only-pointer-1.slang
@@ -0,0 +1,41 @@
+//TEST:SIMPLE(filecheck=CHECK):-stage compute -entry computeMain -target spirv
+
+// Writing with a read-only pointer should be an error
+
+int* processMemory;
+RWStructuredBuffer<int> output;
+
+typealias ReadPtr = Ptr<int, Access::Read, AddressSpace::Device>;
+
+void writeToReadOnlyPointer(ReadPtr ptr)
+{
+ // CHECK: ([[# @LINE+1]]): error 30011
+ ptr[0] = 1;
+}
+
+void writeToReadOnlyPointerOut(out int ptrVal)
+{
+ ptrVal = 1;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int id : SV_DispatchThreadID)
+{
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ ReadPtr ptr1 = ReadPtr(processMemory + id.x);
+
+ // CHECK: ([[# @LINE+1]]): error 30011
+ ptr1[id + 1] = 1;
+ // CHECK: ([[# @LINE+1]]): error 30011
+ *ptr1 = 1;
+
+ writeToReadOnlyPointer(ptr1);
+
+ // CHECK: ([[# @LINE+1]]): error 30047
+ writeToReadOnlyPointerOut(ptr1[1]);
+
+ // CHECK: ([[# @LINE+1]]): error 30047
+ writeToReadOnlyPointerOut(*(ptr1+2));
+
+ output[id] = ptr1[id];
+}
diff --git a/tests/language-feature/pointer/pointer-access/read-only-pointer-2.slang b/tests/language-feature/pointer/pointer-access/read-only-pointer-2.slang
new file mode 100644
index 000000000..c5288caba
--- /dev/null
+++ b/tests/language-feature/pointer/pointer-access/read-only-pointer-2.slang
@@ -0,0 +1,19 @@
+//TEST:SIMPLE(filecheck=CHECK):-stage compute -entry computeMain -target spirv
+
+// Tests valid use of read-only pointer
+
+// CHECK: OpEntryPoint
+// CHECK-NOT: error
+
+int* processMemory;
+int* output;
+
+typealias ReadPtr = Ptr<int, Access::Read, AddressSpace::Device>;
+
+[numthreads(1, 1, 1)]
+void computeMain(int id : SV_DispatchThreadID)
+{
+ ReadPtr ptr1 = ReadPtr(processMemory + id.x);
+
+ output[id] = ptr1[id];
+}
diff --git a/tests/language-feature/pointer/pointer-casting/pointer-casting-rules.slang b/tests/language-feature/pointer/pointer-casting/pointer-casting-rules.slang
new file mode 100644
index 000000000..d0a016fe5
--- /dev/null
+++ b/tests/language-feature/pointer/pointer-casting/pointer-casting-rules.slang
@@ -0,0 +1,51 @@
+//TEST:SIMPLE(filecheck=CHECK):-stage compute -entry computeMain -target spirv
+
+// Tests pointer casting rules: Only explicit casting is allowed between pointer types.
+// All implicit conversions between pointer types should fail.
+int* processMemory;
+RWStructuredBuffer<int> output;
+
+[numthreads(1, 1, 1)]
+void computeMain(int id : SV_DispatchThreadID)
+{
+ // regular address-of
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ Ptr<int, Access::ReadWrite, AddressSpace::Device> rwPtr = processMemory + id.x;
+ // copying a pointer of T* syntax
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ Ptr<int, Access::ReadWrite, AddressSpace::Device> copiedPtrOfLegacySyntax = processMemory;
+ // casting to Read ptr
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ Ptr<int, Access::Read> rPtr = Ptr<int, Access::Read>(processMemory + id.x);
+
+ // casting to RW ptr from a R ptr
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ Ptr<int, Access::ReadWrite, AddressSpace::Device> p1 = Ptr<int, Access::ReadWrite, AddressSpace::Device>(rPtr);
+ // casting to R ptr from a RW ptr
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ Ptr<int, Access::Read, AddressSpace::Device> p2 = Ptr<int, Access::Read, AddressSpace::Device>(rwPtr);
+ // casting to ptr of different type
+ // CHECK-NOT: ([[# @LINE+1]]): error
+ Ptr<float, Access::ReadWrite, AddressSpace::Device> p3 = Ptr<float, Access::ReadWrite, AddressSpace::Device>(rPtr);
+
+ // Cannot implicit cast ptr's
+ // CHECK: ([[# @LINE+1]]): error 30019
+ Ptr<float, Access::ReadWrite, AddressSpace::Device> p4 = rPtr;
+ // cannot implcitly cast between different access qualifiers
+ // CHECK: ([[# @LINE+1]]): error 30019
+ Ptr<int, Access::Read> p5 = Ptr<int, Access::ReadWrite>(processMemory + id.x);
+ // cannot implcitly cast between different access qualifiers
+ // CHECK: ([[# @LINE+1]]): error 30019
+ Ptr<int, Access::ReadWrite> p6 = Ptr<int, Access::Read>(processMemory + id.x);
+
+ // TODO: Enable this when we allow user-defined group-shared address space, Issue #8173.
+ // Cannot cast between different address spaces.
+ // CHECK: ([[# @LINE+1]]): error
+ Ptr<float, Access::ReadWrite, AddressSpace::GroupShared> p7 = Ptr<float, Access::ReadWrite, AddressSpace::GroupShared>(rwPtr);
+ // CHECK: ([[# @LINE+1]]): error
+ Ptr<float, Access::ReadWrite, AddressSpace::GroupShared> p8 = Ptr<float, Access::ReadWrite, AddressSpace::GroupShared>(p1);
+ // CHECK: ([[# @LINE+1]]): error
+ Ptr<float, Access::ReadWrite, AddressSpace::GroupShared> p9 = rwPtr;
+
+ output[id] = *rwPtr;
+}
diff --git a/tests/language-feature/pointer/pointer-self-reference.slang b/tests/language-feature/pointer/pointer-self-reference.slang
index e78b70db0..75ff4e7a9 100644
--- a/tests/language-feature/pointer/pointer-self-reference.slang
+++ b/tests/language-feature/pointer/pointer-self-reference.slang
@@ -1,6 +1,8 @@
// pointer-self-reference.slang
-//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+// We are disabling this test because '&' is intentionally not supported.
+// Design for pointers in Slang are not yet finalized.
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer;
@@ -18,13 +20,13 @@ void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
Thing things[2];
- things[0].next = &things[1];
+ things[0].next = __getAddress(things[1]);
things[0].value = 27;
- things[1].next = &things[0];
+ things[1].next = __getAddress(things[0]);
things[1].value = idx * idx;
- Ptr<Thing> cur = &things[0];
+ Ptr<Thing> cur = __getAddress(things[0]);
for (int i = 0; cur && i < idx; ++i)
{
diff --git a/tests/language-feature/pointer/ptr-to-groupshared.slang b/tests/language-feature/pointer/ptr-to-groupshared.slang
new file mode 100644
index 000000000..3ad4c5e0b
--- /dev/null
+++ b/tests/language-feature/pointer/ptr-to-groupshared.slang
@@ -0,0 +1,30 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
+
+// Tests if we handle passing groupshared address-space pointers correctly to a function
+// when that data-type needs legalization (Data -> Data_natural due to `lower-buffer-element-type`).
+// CHECK: 1
+// CHECK-NEXT: 2
+// CHECK-NEXT: 0
+
+struct Data
+{
+ int value1;
+ int value2;
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
+uniform int* outputBuffer;
+groupshared Data shared;
+
+void foo(Ptr<Data, Access::ReadWrite, AddressSpace::GroupShared> ptr)
+{
+ outputBuffer[0] = ptr.value1;
+ outputBuffer[1] = ptr.value2;
+}
+
+[numthreads(3, 1, 1)]
+void computeMain(uint3 group_thread_id: SV_GroupThreadID)
+{
+ shared = Data(1, 2);
+ foo(__getAddress(shared));
+} \ No newline at end of file
diff --git a/tests/spirv/pointer-from-user-guide.slang b/tests/spirv/pointer-from-user-guide.slang
index 662579c2b..115530b2b 100644
--- a/tests/spirv/pointer-from-user-guide.slang
+++ b/tests/spirv/pointer-from-user-guide.slang
@@ -15,7 +15,7 @@ float test(MyType* pObj)
{
//SPV: OpTypePointer
MyType* pNext = pObj + 1;
- MyType* pNext2 = &pNext[1];
+ MyType* pNext2 = __getAddress(pNext[1]);
return pNext.a + pNext->a + (*pNext2).a + pNext2[0].a;
}
diff --git a/tests/spirv/pointer.slang b/tests/spirv/pointer.slang
index affc52e1b..f3b086e6a 100644
--- a/tests/spirv/pointer.slang
+++ b/tests/spirv/pointer.slang
@@ -28,12 +28,12 @@ void funcWithInOutParam(inout PP p)
// CHECK: OpEntryPoint
-StructuredBuffer<Data> buffer;
+uniform Data* buffer;
RWStructuredBuffer<int> output;
void main(int id : SV_DispatchThreadID)
{
output[0] = buffer[0].pNext.data;
- let pData = &(buffer[0].pNext->data); // operator -> is also allowed on pointer types.
+ let pData = __getAddress(buffer[0].pNext->data); // operator -> is also allowed on pointer types.
// CHECK: OpPtrAccessChain
int* pData1 = pData + 1;
*pData1 = 3;
diff --git a/tests/spirv/ptr-vector-member.slang b/tests/spirv/ptr-vector-member.slang
index 0683d8838..6e493b4c7 100644
--- a/tests/spirv/ptr-vector-member.slang
+++ b/tests/spirv/ptr-vector-member.slang
@@ -1,19 +1,17 @@
-//TEST:SIMPLE(filecheck=CHECK): -target spirv
+//DISABLE_TEST:SIMPLE(filecheck=SPIRV):-stage compute -entry computeMain -target spirv
+//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly
-// CHECK: %[[PTR:[0-9a-zA-Z_]+]] = OpAccessChain %_ptr_PhysicalStorageBuffer_uint %16 %int_0
-// CHECK: %{{.*}} = OpAtomicIAdd %uint %[[PTR]] %uint_1 %uint_0 %uint_1
+// SPIRV: OpEntryPoint
+// SPIRV-NOT: error
-struct Push2
-{
- uint4 * value;
-};
-
-[[vk::push_constant]] Push2 push2;
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+uniform int4* output;
[shader("compute")]
[numthreads(1, 1, 1)]
-void main()
+void computeMain()
{
- uint * v = &push2.value[0].x;
+ // CHECK: 1
+ int* v = __getAddress(output[0].x);
InterlockedAdd(*v, 1);
} \ No newline at end of file
diff --git a/tools/gfx/gfx.slang b/tools/gfx/gfx.slang
index 8dae5227e..f901fbe28 100644
--- a/tools/gfx/gfx.slang
+++ b/tools/gfx/gfx.slang
@@ -1967,7 +1967,7 @@ SLANG_GFX_IMPORT public bool gfxIsTypelessFormat(Format format);
SLANG_GFX_IMPORT public Result gfxGetFormatInfo(Format format, FormatInfo *outInfo);
/// Given a type returns a function that can conpublic struct it, or nullptr if there isn't one
-SLANG_GFX_IMPORT public Result gfxCreateDevice(const DeviceDesc* desc, out Optional<IDevice> outDevice);
+SLANG_GFX_IMPORT public Result gfxCreateDevice(const Ptr<DeviceDesc> desc, out Optional<IDevice> outDevice);
/// Reports current set of live objects in gfx.
/// Currently this only calls D3D's ReportLiveObjects.