diff options
| author | Theresa Foley <10618364+tangent-vector@users.noreply.github.com> | 2025-09-22 18:20:13 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-23 01:20:13 +0000 |
| commit | c35b763f811298a6e9c61a4a8eaf805ea98bd608 (patch) | |
| tree | c981b9b88939b8920ea291c3f4a6ba828535a946 /source | |
| parent | ba8132345cbae5b749b4a01deda732ad6f8251a0 (diff) | |
Split overloaded uses of RefType in front-end (#8427)
Overview
========
This change is the start of an attempt to address how the Slang compiler
codebase has ended up conflating two similar, but semantically distinct,
concepts:
* The long-standing notion of `ref` parameters (only allowed for use in
the builtin modules), which are encoded using a wrapper `Type` in the
AST as part of the representation of the parameters of a `FuncType`.
* A recently-introduced notion of explicit reference types that mirror
the built-in `Ptr` type, with a relationship comparable to that between
pointer and reference types in C++.
The change splits the `Ref<T>` type in the core module into two distinct
types, with one for each of the two use cases. Similarly, the `RefType`
class in the compiler's AST is split into two distinct classes, to
represent the two cases.
Background
==========
The `Ref<T>` type in the core module (hidden and not intended for users
to ever see or use) was originally introduced to encode the `ref`
parameter-passing mode, comparable to the hidden `Out<T>` and `InOut<T>`
types used to encode `out` and `inout` parameter-passing modes. The
`Ref<T>` type in the core module was encoded as a instance of the
`RefType` class in the Slang AST (similar to how `Out<T>` mapped to an
`OutType`). These AST classes were *only* intended to be used by the
compiler front-end as part of its encoding of function types. The
`FuncType` class needed a way to distinguish an `inout int` parameter
from a plain (implicitly `in`) `int` parameter, so these wrapper like
`RefType` and `OutType` were introduced to encode both the parameter
type (`T`) and the parameter-passing mode in a form that could be passed
around as a `Type`.
Notably, the `Ref<T>` type (and `Out<T>`, etc.) were *not* intended to
be type names that ever get uttered in Slang code (not even in the
builtin modules), and the vast majority of the compiler code was not
supposed to ever encounter them. They were an implementation detail of
`FuncType`, and nothing else.
(In hindsight it may have been a mistake to use a nominal type declared
in the core module to implement these wrappers; it might have been a
good idea to use an entirely separate class of `Type` for this case...)
Recent changes to the builtin modules introduced functions that wanted
to *return* a reference (so that the parameter-passing-mode modifiers
like `ref` could not trivially be used), and as part of those changes
the appealingly-named `Ref<T>` type in the core module was re-used for
this new case. Builtin operations were declared with an explicit
`Ref<T>` return type, and parts of the compiler front-end that had
previously been blissfully unaware of the AST's `RefType` (and
`InOutType`, etc.) had to start accounting for the possibility that an
explicit `Ref<T>` would show up.
Related changes also introduced a comparable conflation of the
(unfortunately-named) `constref` parameter-passing modifier and builtin
operations that wanted to return an explicit reference that is
read-only. Both use cases were mapped to the core-module `ConstRef<T>`
type, which appeared in the AST as an instance of the `ConstRefType`
class.
The overlapping use of `ConstRef<T>`` is actually significantly more
troublesome than the `Ref<T>` case because, despite what its name
implies, `constref` was not really supposed to be the read-only analogue
of `ref`, but rather it is closer to the "immutable value borrow"
analogue to `inout`'s "mutable value borrow." The semantics of a "value
borrow" vs. a "memory reference" in Slang have not been very carefully
codified, and the conflation around `ConstRef<T>` has contributed to
things becoming increasingly muddy in the compiler back-end.
Main Changes
============
Core Module
-----------
The `Ref<T>` type has been replaced with two distinct types, with one
for each use case:
* `RefParam<T>` is intended for use when encoding a `ref` parameter in a
function type
* `ExplicitRef<T>` is intended for use when an operation in a builtin
module wants to return a reference
The other types used to represent parameter-passing modes (e.g.,
`InOut<T>`) were renamed to better indicate that their role in defining
parameter types (e.g., `InOutParam<T>`).
The `ExplicitRef<T>` type was given additional generic parameters for
the allowed access and the address space, akin to what `Ptr<T>` now
supports. The pointer dereference operator (prefix `*`) in the core
module should now properly propagate the access and address space of the
pointer over to the reference that gets returned.
The two distinct use cases of `ConstRef<T>` were not split in the way as
`Ref<T>`, instead the case for the `constref` parameter-passing mode
uses `ConstParamRef<T>`, while cases that previously used `ConstRef<T>`
to represent a read-only explicit reference instead now use
`ExplicitRef<T, Access.Read>`.
Prior to this change there were two subscripts declared on pointers: one
in the `Ptr` type itself, and another in an `extension` for pointers
with `Access.ReadWrite`. The comments on the code seemed to indicate
that the catch-all subscript used to only have a `get` accessor, while
the `ref` was only available on read-write pointers, but it seems that
subsequent changes converted the default subscript to support `ref`.
This change eliminates the subscript added via `extension`, since it is
redundant.
AST and Front-End
=================
Similar to the changes in the core module, the AST `RefType` class was
split into:
* `RefParamType` for the case of encoding `ref` parameters
* `ExplicitRefType` for the case where the user meant an explicit
reference type
All the other classes that represent wrappers for encoding
parameter-passing modes (e.g., `OutType`) were similarly renamed (e.g.,
`OutParamType`).
The `ConstRefType` class was simply renamed to `ConstRefParamType`,
because any use cases of `ConstRefType` that intended an explicit
reference type will now use `ExplicitRefType` with `Acccess.Read`.
For convenience, this change includes type aliases to map the old names
for these types over to the new ones (e.g., `using OutType =
OutParamType`) so that the change doesn't need to affect quite so many
lines of code. The `RefType` and `ConstRefType` names are intentionally
left undefined, since it woudl be unsafe to assume that existing use
sites should default to either of the two possible interpretations.
All use cases of `RefType` and `ConstRefType` (and their former shared
base class `RefTypeBase`) were audited and updated to refer to either
`RefParamType`/`ConstRefParamType` or `ExplicitRefType`, as appropriate
(based on whether the context of the code indicated it was working with
parameter-passing mode wrapper types, or explicit reference types).
In many (many) cases comments were added to the code that was updated
(and some unrelated code that needed to be audited along the way) to
note cases where there appears to be something fishy going on in the
compiler and/or there are obvious opportunities for next-step
improvement.
The `QualType` constructor used to infer l-value-ness when passed a
`RefType` or `ConstRefType`; that code was introduced to support
explicit reference types. The code was updated to consult the access
argument of an `ExplicitRefType` to try and determine the right
l-value-ness to use. There is some ambiguity about what should be done
in the case where the value of the generic argument representing the
access cannot be statically determined; a better solution may be needed.
Many other cases in the front-end that were working with `RefType` and
`ConstRefType` for explicit references also need to figure out
l-value-ness, and these were changed to rely on the logic already added
to `QualType` so that it wouldn't have to be duplicated. It isn't clear
if this structure is the best way to tackle the problem, but it seems to
at least be an upgrade over the more strictly ad-hoc logic that was in
place before.
Future Work
===========
IR-Level Work
-------------
The most obvious next step to take is that the split that was made in
the compiler front-end needs to be properly plumbed through all of the
back-end. There appears to be a lot of code in the back end of the
compiler that has made the same conflation of `ref` parameters and
explicit reference types that the front-end did. In practice, any uses
of `ExplicitRef<T>` in the front-end should desugar into plain
pointer-based code in the IR.
Clean Up Parameter-Passing Modes
--------------------------------
The code that handles different parameter-passing modes
(`ParameterDirection`s) and their wrapper types is somewhat scattered
and messy (as found while auditing use cases of `RefType`). A cleanup
pass is warranted to ensure that most code only needs to think about
`ParameterDirection`s. There should ideally be only a single operation
in the front-end that handles determining the `ParameterDirection` of a
parameter based on its modifiers. Similarly, there should be one
operation to wrap a value type based on a parameter direction, and one
operation to derive a `ParameterDirection` from the wrapper type.
Ideally, the accessors for `FuncType` should not provide unrestricted
access to the potentially-wrapped parameter types, and should instead
return some kind of `ParamInfo` struct that encodes both a
`ParameterDirection` and the unwrapped `Type` of the parameter.
Clean Up `QualType`
-------------------
A significant piece of future work that appears required is to
drastically clean up and improve the way that `QualType`s are represente
and handled in the front-end. There are currently various distinct
`bool` flags in `QualType` (some with very unclear meaning) and
differnet parts of the codebase consult/modify only subsets of them; a
clear enumeration of the "value categories" (to use the C++ terminology)
that Slang supports could be quite helpful. Naively, a `QualType` should
at least encode the basic information that a `Ptr` type encodes:
* A value type
* Allowed access (read-only, read-write, etc.)
* Address space
The main additional thing that a `QualType` needs is a way to
distinguish cases where an expression evaluates to:
* A reference to a memory location, where all the information from a
`Ptr` is relevant
* A simple value, such that the access and address space are irrelevant
* A reference to an abstract storage location (a `property`,
`subscript`, or an implicit conversion that needs to support being an
l-value), in which case address space is irrelevant and the "allowed
access" basically amounts to a listing of the accessors the storage
location supports
Eliminate Explicit Reference Types
----------------------------------
Finally, twe should eventually eliminate the `ExplicitRef<T>` type from
the core module (and all of the supporting code from the front-end),
since the feature is not a good fit for the Slang language. We should
find some other way to decorate operations in the builtin module that
need to returns a reference rather than a value (note how `ref`
accessors already avoided exposing explicit reference types, by design).
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 90 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 17 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 20 | ||||
| -rw-r--r-- | source/slang/slang-ast-natural-layout.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 73 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 139 | ||||
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 156 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 118 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-fix-entrypoint-callsite.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 57 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-type-layout.cpp | 21 |
21 files changed, 658 insertions, 164 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 5c4a35be2..43658563e 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1284,6 +1284,7 @@ enum AddressSpace : uint64_t { Device = $((uint64_t)AddressSpace::UserPointer), GroupShared = $((uint64_t)AddressSpace::GroupShared), + VaryingInput = $((uint64_t)AddressSpace::Input), }; /// @category misc_types @@ -1328,29 +1329,49 @@ struct Ptr< __intrinsic_op($(kIROp_CastIntToPtr)) __init(int64_t val); - // By default, getter is not an L value + /// Subscript a pointer to get a reference to a value. + /// + /// The pointer must reference a memory location where + /// N >= 0 values of type `T` are stored sequentially, + /// in a layout consistent with an array `T[N]`; + /// otherwise, this operation has undefined behavior. + /// + /// The `index` parameter must satisfy 0 <= `index` <= N; + /// otherwise, this operation has undefined behavior. + /// __generic<TInt : __BuiltinIntegerType> __subscript(TInt index) -> T { - __intrinsic_op($(kIROp_GetOffsetPtr)) - [nonmutating] - ref; - } -}; + // + // TODO(tfoley): The `Ptr` type's subscript operation + // currently provides a `ref` accessor (which returns a + // reference that can be used for reading or writing), + // independent of the `Access` that is specified by the + // generic arguments of a particular pointer type. In + // practice, this means that subscripting a read-only + // pointer yields a readable *and* writable reference. + // + // Previous versions of the code module attempted to + // address the problem by providing a `get` accessor + // on all platforms, and then use an `extension` to + // introduce a `ref` accessor only to pointers that + // are known to be mutable. That approach does not work, + // however, because it is necessary that subscripting + // a pointer yields a reference (l-value), whether or not + // the reference is read-only. + // + // A different solution to the problem is needed, whether + // by splitting up read-only and read-write pointers as + // distinct types (as many languages do), or by allowing + // the subscript operation to explicitly return a reference + // with the same access as the pointer itself. + // -extension<T, AddressSpace addrSpace> Ptr<T, Access::ReadWrite, addrSpace> -{ - // We have a `ref` accessor if we are ReadWrite. This means only `ReadWrite` - // can be used as an L-value. - __generic<TInt : __BuiltinIntegerType> - __subscript(TInt index) -> Ref<T> - { - // If a 'Ptr[index]' is referred to by a '__ref', call 'kIROp_GetOffsetPtr(index)' __intrinsic_op($(kIROp_GetOffsetPtr)) [nonmutating] ref; } -} +}; //@hidden: __intrinsic_op($(kIROp_AlignedAttr)) @@ -1505,28 +1526,33 @@ extension uintptr_t : IRangedValue } //@hidden: -__generic<T> -__magic_type(OutType) + +__magic_type(ExplicitRefType) +__intrinsic_type($(kIROp_RefType)) +struct Ref< + T, + Access access = Access::ReadWrite, + AddressSpace addrSpace = AddressSpace::Device> +{}; + +__magic_type(OutParamType) __intrinsic_type($(kIROp_OutType)) -struct Out +struct OutParam<T> {}; -__generic<T> -__magic_type(InOutType) +__magic_type(InOutParamType) __intrinsic_type($(kIROp_InOutType)) -struct InOut +struct InOutParam<T> {}; -__generic<T> -__magic_type(RefType) +__magic_type(RefParamType) __intrinsic_type($(kIROp_RefType)) -struct Ref +struct RefParam<T> {}; -__generic<T> -__magic_type(ConstRefType) +__magic_type(ConstRefParamType) __intrinsic_type($(kIROp_ConstRefType)) -struct ConstRef +struct ConstRefParam<T> {}; // __Addr<T> is AddressSpace::Generic since Slang will specalize & validate the address-space @@ -2669,15 +2695,9 @@ for (auto op : intrinsicUnaryOps) }}}} // Only ReadWrite is an L-value. -__generic<T, AddressSpace addrSpace> -__intrinsic_op(0) -__prefix Ref<T> operator*(Ptr<T, Access::ReadWrite, addrSpace> value); - -// Unknown access qualifier or Access::Read access qualifier is a promise -// that the pointer is not going to be used as an L-value. -__generic<$(ptrTypeParameterList)> +__generic<T, Access access, AddressSpace addrSpace> __intrinsic_op(0) -__prefix ConstRef<T> operator*($(fullPtrType) value); +__prefix Ref<T, access, addrSpace> operator*(Ptr<T, access, addrSpace> value); // TODO: [require(cpu)]. This cannot be done yet since this change breaks slangpy __generic<T> diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 96adabbde..ff30a921c 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -10354,10 +10354,10 @@ matrix<T, N, M> fwidth_fine(matrix<T, N, M> x) } __intrinsic_op($(kIROp_ResolveVaryingInputRef)) -Ref<T> __ResolveVaryingInputRef<T>(__constref T attribute); +Ref<T, Access.Read, AddressSpace.VaryingInput> __ResolveVaryingInputRef<T>(__constref T attribute); __intrinsic_op($(kIROp_GetPerVertexInputArray)) -Ref<Array<T, 3>> __GetPerVertexInputArray<T>(__constref T attribute); +Ref<Array<T, 3>, Access.Read, AddressSpace.VaryingInput> __GetPerVertexInputArray<T>(__constref T attribute); T __GetAttributeAtVertex<T>(__constref T attribute, uint vertexIndex) { diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index a88db3155..f7304f308 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -512,22 +512,27 @@ Type* ASTBuilder::getScalarLayoutType() // Construct the type `Out<valueType>` OutType* ASTBuilder::getOutType(Type* valueType) { - return dynamicCast<OutType>(getPtrType(valueType, "OutType")); + return dynamicCast<OutType>(getPtrType(valueType, "OutParamType")); } InOutType* ASTBuilder::getInOutType(Type* valueType) { - return dynamicCast<InOutType>(getPtrType(valueType, "InOutType")); + return dynamicCast<InOutType>(getPtrType(valueType, "InOutParamType")); } -RefType* ASTBuilder::getRefType(Type* valueType) +RefParamType* ASTBuilder::getRefParamType(Type* valueType) { - return dynamicCast<RefType>(getPtrType(valueType, "RefType")); + return dynamicCast<RefParamType>(getPtrType(valueType, "RefParamType")); } -ConstRefType* ASTBuilder::getConstRefType(Type* valueType) +ConstRefParamType* ASTBuilder::getConstRefParamType(Type* valueType) { - return dynamicCast<ConstRefType>(getPtrType(valueType, "ConstRefType")); + return dynamicCast<ConstRefParamType>(getPtrType(valueType, "ConstRefParamType")); +} + +ExplicitRefType* ASTBuilder::getExplicitRefType(Type* valueType) +{ + return dynamicCast<ExplicitRefType>(getPtrType(valueType, "ExplicitRefType")); } OptionalType* ASTBuilder::getOptionalType(Type* valueType) diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 72a3db4ab..9386180c8 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -536,17 +536,23 @@ public: PtrType* getPtrType(Type* valueType, AccessQualifier accessQualifier, AddressSpace addrSpace); PtrType* getPtrType(Type* valueType, Val* accessQualifier, Val* addrSpace); - // Construct the type `Out<valueType>` - OutType* getOutType(Type* valueType); + // Construct the type `OutParam<valueType>` + OutParamType* getOutType(Type* valueType); - // Construct the type `InOut<valueType>` - InOutType* getInOutType(Type* valueType); + // Construct the type `InOutParam<valueType>` + InOutParamType* getInOutType(Type* valueType); + + // Construct the type `RefParam<valueType>` + RefParamType* getRefParamType(Type* valueType); + + // Construct the type `ConstRefParam<valueType>` + ConstRefParamType* getConstRefParamType(Type* valueType); // Construct the type `Ref<valueType>` - RefType* getRefType(Type* valueType); + ExplicitRefType* getExplicitRefType(Type* valueType); - // Construct the type `ConstRef<valueType>` - ConstRefType* getConstRefType(Type* valueType); + // Construct the type `Ref<valueType, .Read>` + ExplicitRefType* getExplicitConstRefType(Type* valueType); // Construct the type `Optional<valueType>` OptionalType* getOptionalType(Type* valueType); diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp index 7643e0433..27f4d43eb 100644 --- a/source/slang/slang-ast-natural-layout.cpp +++ b/source/slang/slang-ast-natural-layout.cpp @@ -174,8 +174,30 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) } else if (auto optionalType = as<OptionalType>(type)) { - if (isNullableType(optionalType->getValueType())) + // Sometimes a type `T` has an unused bit pattern that + // can be used to represent the null/absent optional value, + // and for such types the size of an `Optional<T>` can be + // the same as a `T`, by making use of that unused pattern. + // + if (doesTypeHaveAnUnusedBitPatternThatCanBeUsedForOptionalRepresentation( + optionalType->getValueType())) return calcSize(optionalType->getValueType()); + + // For all other types, an `Optional<T>` is laid out more-or-less + // as a tuple of a `bool` and a `T`. + // + // TODO(tfoley): This appears to be the exact *opposite* of how + // we should be laying out optionals if we want to be at all + // efficient about space. For various targets and layout modes + // (with natural layout currently being one of them), a type + // can have "tail padding," when its size is not a multiple of + // its alignment. In such cases laying things out as `(T, bool)` + // can both end up takign advantage of the tail padding of `T` + // when present *or* for types `T` that don't include tail + // padding in their layout, but have an alignment N > 1 + // the `(T, bool)` order will then *create* N-1 bytes of tail + // padding (that can possibly be exploited elsewhere). + // NaturalSize size = NaturalSize::makeEmpty(); size.append(calcSize(m_astBuilder->getBoolType())); size.append(calcSize(optionalType->getValueType())); diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index 1f5f4f8b2..9d9fdb7da 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -11,13 +11,26 @@ namespace Slang QualType::QualType(Type* type) : type(type), isLeftValue(false) { - if (as<RefType>(type)) + if (auto refType = as<ExplicitRefType>(type)) { - isLeftValue = true; - } - else if (as<ConstRefType>(type)) - { - isLeftValue = false; + if (auto optAccessQualifier = refType->tryGetAccessQualifierValue()) + { + auto accessQualifier = *optAccessQualifier; + switch (accessQualifier) + { + case AccessQualifier::ReadWrite: + isLeftValue = true; + break; + + case AccessQualifier::Read: + isLeftValue = false; + break; + + default: + SLANG_UNEXPECTED("unhandled access qualifier"); + break; + } + } } } diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 53f6626d7..1af81a88f 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -450,15 +450,19 @@ Val* PtrTypeBase::getAddressSpace() return _getGenericTypeArg(this, 2); } -AccessQualifier tryGetAccessQualifierValue(Val* val) +std::optional<AccessQualifier> tryGetAccessQualifierValue(Val* val) { - AccessQualifier accessQualifier = AccessQualifier::ReadWrite; - if (auto cintVal = as<ConstantIntVal>(val)) { - accessQualifier = (AccessQualifier)(cintVal->getValue()); + return AccessQualifier(cintVal->getValue()); } - return accessQualifier; + return std::optional<AccessQualifier>(); +} + +std::optional<AccessQualifier> PtrTypeBase::tryGetAccessQualifierValue() +{ + auto accessQualifierArg = this->getAccessQualifier(); + return Slang::tryGetAccessQualifierValue(accessQualifierArg); } AddressSpace tryGetAddressSpaceValue(Val* addrSpaceVal) @@ -517,24 +521,44 @@ void maybePrintAccessQualifierOperand(StringBuilder& out, AccessQualifier access void PtrType::_toTextOverride(StringBuilder& out) { - auto accessQualifier = tryGetAccessQualifierValue(getAccessQualifier()); auto addrSpace = tryGetAddressSpaceValue(getAddressSpace()); out << toSlice("Ptr<") << getValueType(); - maybePrintAccessQualifierOperand(out, accessQualifier); + if (auto optionalAccessQualifier = tryGetAccessQualifierValue()) + maybePrintAccessQualifierOperand(out, *optionalAccessQualifier); maybePrintAddrSpaceOperand(out, addrSpace); out << toSlice(">"); } -void RefType::_toTextOverride(StringBuilder& out) +void ExplicitRefType::_toTextOverride(StringBuilder& out) { - auto accessQualifier = tryGetAccessQualifierValue(getAccessQualifier()); auto addrSpace = tryGetAddressSpaceValue(getAddressSpace()); out << toSlice("Ref<") << getValueType(); - maybePrintAccessQualifierOperand(out, accessQualifier); + if (auto optionalAccessQualifier = tryGetAccessQualifierValue()) + maybePrintAccessQualifierOperand(out, *optionalAccessQualifier); maybePrintAddrSpaceOperand(out, addrSpace); out << toSlice(">"); } +void OutParamType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("out ") << getValueType(); +} + +void InOutParamType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("inout ") << getValueType(); +} + +void RefParamType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("ref ") << getValueType(); +} + +void ConstRefParamType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("borrow ") << getValueType(); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -556,14 +580,13 @@ Type* NamedExpressionType::_createCanonicalTypeOverride() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -ParameterDirection FuncType::getParamDirection(Index index) +ParameterDirection getParamPassingModeFromPossiblyWrappedParamType(Type* paramType) { - auto paramType = getParamType(index); - if (as<RefType>(paramType)) + if (as<RefParamType>(paramType)) { return kParameterDirection_Ref; } - else if (as<ConstRefType>(paramType)) + else if (as<ConstRefParamType>(paramType)) { return kParameterDirection_ConstRef; } @@ -581,6 +604,21 @@ ParameterDirection FuncType::getParamDirection(Index index) } } +ParameterDirection FuncType::getParamDirection(Index index) +{ + auto paramType = getParamTypeWithDirectionWrapper(index); + return getParamPassingModeFromPossiblyWrappedParamType(paramType); +} + +Type* FuncType::getParamValueType(Index index) +{ + auto paramType = getParamTypeWithDirectionWrapper(index); + if (auto wrappedParamType = as<ParamDirectionType>(paramType)) + return wrappedParamType->getValueType(); + return paramType; +} + + void FuncType::_toTextOverride(StringBuilder& out) { Index paramCount = getParamCount(); @@ -591,7 +629,7 @@ void FuncType::_toTextOverride(StringBuilder& out) { out << toSlice(", "); } - out << getParamType(pp); + out << getParamTypeWithDirectionWrapper(pp); } out << ") -> " << getResultType(); @@ -615,7 +653,8 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s List<Type*> substParamTypes; for (Index pp = 0; pp < getParamCount(); pp++) { - auto substParamType = as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff)); + auto substParamType = as<Type>( + getParamTypeWithDirectionWrapper(pp)->substituteImpl(astBuilder, subst, &diff)); if (auto typePack = as<ConcreteTypePack>(substParamType)) { // Unwrap the ConcreteTypePack and add each element as a parameter @@ -650,7 +689,7 @@ Type* FuncType::_createCanonicalTypeOverride() List<Type*> canParamTypes; for (Index pp = 0; pp < getParamCount(); pp++) { - canParamTypes.add(getParamType(pp)->getCanonicalType()); + canParamTypes.add(getParamTypeWithDirectionWrapper(pp)->getCanonicalType()); } FuncType* canType = getCurrentASTBuilder()->getFuncType( diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 4994328b2..26d86fc7f 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -687,6 +687,8 @@ class PtrTypeBase : public BuiltinType Type* getValueType(); Val* getAccessQualifier(); Val* getAddressSpace(); + + std::optional<AccessQualifier> tryGetAccessQualifierValue(); }; FIDDLE() @@ -709,7 +711,14 @@ class PtrType : public PtrTypeBase void _toTextOverride(StringBuilder& out); }; -/// A pointer-like type used to represent a parameter "direction" +/// A pointer-like type used to represent a parameter-passing mode. +/// +/// Historically the codebase has referredd to different parameter-passing +/// modes as parameter "directions," because they initially included +/// only `in`, `out`, and `inout`. The name is confusing when applied +/// to things like `ref` parameters, but we haven't had time to rename +/// everything yet. +/// FIDDLE() class ParamDirectionType : public PtrTypeBase { @@ -720,44 +729,68 @@ class ParamDirectionType : public PtrTypeBase // logical pointer that is passed for an `out` // or `in out` parameter FIDDLE(abstract) -class OutTypeBase : public ParamDirectionType +class OutParamTypeBase : public ParamDirectionType { FIDDLE(...) }; +using OutTypeBase = OutParamTypeBase; // The type for an `out` parameter, e.g., `out T` FIDDLE() -class OutType : public OutTypeBase +class OutParamType : public OutParamTypeBase { FIDDLE(...) + void _toTextOverride(StringBuilder& out); }; +using OutType = OutParamType; // The type for an `in out` parameter, e.g., `in out T` FIDDLE() -class InOutType : public OutTypeBase +class InOutParamType : public OutParamTypeBase { FIDDLE(...) + void _toTextOverride(StringBuilder& out); }; +using InOutType = InOutParamType; -FIDDLE(abstract) -class RefTypeBase : public ParamDirectionType +// The type for an `ref` parameter, e.g., `ref T` +FIDDLE() +class RefParamType : public ParamDirectionType { FIDDLE(...) + void _toTextOverride(StringBuilder& out); }; -// The type for an `ref` parameter, e.g., `ref T` +/// The type for a `constref` parameter, e.g., `constref T` +/// +/// Note that, despite the modifier currently used to represent +/// this case in code, this is *not* comparable to the `ref` +/// parameter-passing mode, and is instead an input-only +/// equivalent of `inout`. +/// FIDDLE() -class RefType : public RefTypeBase +class ConstRefParamType : public ParamDirectionType { FIDDLE(...) void _toTextOverride(StringBuilder& out); }; -// The type for an `constref` parameter, e.g., `constref T` +/// A reference type that is explicitly named somewhere in code (`Ref<T>`). +/// +/// The explicit reference types are distinct from the +/// parameter-passing mode wrapper types like `RefParamType`. +/// An explicit reference type is a type that code written in +/// Slang is allowed to name (e.g., by having a function that +/// returns a `Ref<T>`), even if those uses may only occur +/// in the core module. In constrast, the parameter-passing +/// mode wrapper types should only ever be used as part of +/// the encoding of a `FuncType`. +/// FIDDLE() -class ConstRefType : public RefTypeBase +class ExplicitRefType : public PtrTypeBase { FIDDLE(...) + void _toTextOverride(StringBuilder& out); }; FIDDLE() @@ -812,12 +845,92 @@ class FuncType : public Type OperandView<Type> getParamTypes() { return OperandView<Type>(this, 0, getOperandCount() - 2); } Index getParamCount() { return m_operands.getCount() - 2; } - Type* getParamType(Index index) { return as<Type>(getOperand(index)); } - Type* getResultType() { return as<Type>(getOperand(m_operands.getCount() - 2)); } - Type* getErrorType() { return as<Type>(getOperand(m_operands.getCount() - 1)); } + /// Get the type of one of the function's parameters, by index. + /// + /// The type returned by this function may include a wrapper + /// type around what the user-perceived type of the parameter + /// is. For example, if a parameter is declared as `out int a` + /// then this function would return a type coresponding to + /// `OutParam<int>`, using the hidden `OutParam<T>` type defined + /// in the core module. + /// + /// Any code that calls this function should be conscious of + /// the possibility of encountering these wrappers, and handle + /// them accordingly. + /// + Type* getParamTypeWithDirectionWrapper(Index index) { return as<Type>(getOperand(index)); } + + /// Get the type of one of the function's parameters, by index. + /// + /// The type returned by this funciton is the user-perceived + /// type of the parameter, and does not include any wrappers + /// that are introduced to indicate the parameter-passing mode. + /// For example, a parameter declared as `out int a` will simply + /// return the `int` type, the same as would be returned for + /// a parameter simply declared as `int a`. + /// + /// Any code that calls this function should be conscious of + /// the possibility that the type returned may not fully + /// describe the contract for the given parameter, and should + /// make sure to consult `getParamDirection` as well, to get + /// a complete picture. + /// + Type* getParamValueType(Index index); + + /// Get the parameter-passing mode of one of the function's parameters, by index. + /// ParameterDirection getParamDirection(Index index); + /// Combined information on the type and parameter-passing mode of a parameter. + /// + struct ParamInfo + { + /// The parameter-passing mode used for the parameter. + ParameterDirection direction = kParameterDirection_In; + + /// The user-perceived type of the parameter. + Type* type = nullptr; + }; + + /// Get combined information on the type and parameter-passing mode of a parameter. + /// + ParamInfo getParamInfo(Index index) + { + ParamInfo info; + info.direction = getParamDirection(index); + info.type = getParamValueType(index); + return info; + } + + /// Get the result type of this function. + /// + /// This is the type that a call to the function evaluates to if + /// the function returns successfully. + /// + /// A function that conceptually returns no value will have the `Unit` + /// type as its result type. + /// + /// A type that can never return will have the bottom type `Never` + /// as its result type. + /// + Type* getResultType() { return as<Type>(getOperand(m_operands.getCount() - 2)); } + + /// Get the type of errors (if any) that this function can fail with. + /// + /// Evaluation of a call to a function with this `FuncType` may fail + /// with an error of the corresponding error type. + /// + /// A function that cannot fail with an error will have the bottom + /// type `Never` as its error type. + /// + /// Note that a function that "never fails" at the type system level + /// may still fail in various ways that are perceivable to the user. + /// The error type of a function only refers to failure modes that + /// are being explicitly modeled using the Slang type system. + /// + Type* getErrorType() { return as<Type>(getOperand(m_operands.getCount() - 1)); } + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 9355834b5..0a96eaaef 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -1110,8 +1110,8 @@ bool SemanticsVisitor::TryUnifyTypesByStructuralMatch( if (!TryUnifyTypes( constraints, unifyCtx, - fstFunType->getParamType(i), - sndFunType->getParamType(i))) + fstFunType->getParamTypeWithDirectionWrapper(i), + sndFunType->getParamTypeWithDirectionWrapper(i))) return false; } return TryUnifyTypes( diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index d8643c63f..1dcd0d737 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1465,7 +1465,7 @@ bool SemanticsVisitor::_coerce( } } - // matrix type with different layouts are convertible + // matrix types with different layouts are convertible if (auto fromMatrixType = as<MatrixExpressionType>(fromType)) { if (auto toMatrixType = as<MatrixExpressionType>(toType)) @@ -1491,6 +1491,14 @@ bool SemanticsVisitor::_coerce( } } + // We allow a value of a `struct` type to be coerced to a function + // type if the `struct` provides an appropriate method for calling + // instances of that type. + // + // TODO(tfoley): This can and should be opened up to work for any + // type (or at least any nominal type) that supports the required + // operation. + // if (auto toFuncType = as<FuncType>(toType)) { if (auto fromLambdaType = isDeclRefTypeOf<StructDecl>(fromType)) @@ -1541,6 +1549,16 @@ bool SemanticsVisitor::_coerce( // Is toType and fromType the same via some type equality witness? // If so there is no need to do any conversion. // + // Note that this is a somewhat messy case to have, since we *already* + // have a check for type equality above this point. For this code to + // execute we would need to have a case where the `To` and `From` types + // are considered distinct by `Type::equals` but `tryGetSubtypeWitness` + // is still able to produce a witness for the equality of the two types. + // + // TODO(tfoley): Try to set things up so that we can have an invariant + // that two types count as equal for `Type::equals` if and only if a + // type equality witness for those types can be dervied. + // if (isTypeEqualityWitness(fromIsToWitness)) { if (outToExpr) @@ -1562,29 +1580,54 @@ bool SemanticsVisitor::_coerce( return _failedCoercion(toType, outToExpr, fromExpr, sink); } - // We allow implicit conversion of a parameter group type like - // `ConstantBuffer<X>` or `ParameterBlock<X>` to its element - // type `X`. + // If the type that we are converting from is a parameter group type + // (something like `ConstantBuffer<X>` or `ParameterBlock<X>`) and we + // are converting to some type `Y`, then we want to allow for a multi-step + // conversion where we first implicitly dereference the parameter group + // to get an `X`, and then convert the resulting `X` to a `Y`. + // + // An important special case of the above is when `X == Y`, in which + // case we are just converting, e.g., a `ConstantBuffer<X>` to an `X`. + // + // TODO(tfoley): When this conditional detects a parameter group type + // it funnels the coercion logic into only considering conversions that + // involve an automatic dereference. We need to ensure that any other + // kinds of conversion that could apply to a parameter group are considered + // earlier in this function, or else they will never actually be considered. + // Notably, with this logic in place it is impossible for there to be any + // conversion operations from a parameter-group type defined in code + // (e.g., a constructor for a `DescriptorHandle`-like type that takes + // a `ConstantBufer<T>` parameter will never be considered as part of conversion + // logic, because we will first extract the `T` and then try to convert *that*). // if (auto fromParameterGroupType = as<ParameterGroupType>(fromType)) { auto fromElementType = fromParameterGroupType->getElementType(); - // If we convert, e.g., `ConstantBuffer<A> to `A`, we will allow - // subsequent conversion of `A` to `B` if such a conversion - // is possible. - // - ConversionCost subCost = kConversionCost_None; - DerefExpr* derefExpr = nullptr; if (outToExpr) { + // TODO(tfoley): The logic here effectively assumes that any + // parameter-group type is read-only, because we are not + // setting the `isLeftValue` flag of the `QualType` based + // on the type of the container. That is, a `StorageBuffer<X>` + // and a `ConstantBuffer<X>` would both derive the `QualType` + // of the dereferenced expression from `X` alone, and ignore + // that one of these should yield an l-value and the other + // shouldn't. + // + // In practice, we should have a centralized function that + // handles dereferenencing of any `Expr`, and computes the + // correct type for the result, so that the logic here can + // exactly mirror other cases of implicit dereference. + // derefExpr = m_astBuilder->create<DerefExpr>(); derefExpr->base = fromExpr; derefExpr->type = QualType(fromElementType); derefExpr->checked = true; } + ConversionCost subCost = kConversionCost_None; if (!_coerce(site, toType, outToExpr, fromElementType, derefExpr, sink, &subCost)) { return false; @@ -1595,40 +1638,74 @@ bool SemanticsVisitor::_coerce( return true; } - if (auto refType = as<RefTypeBase>(toType)) - { + // Because (for various bad reasons) we currently support an explicit + // `Ref<T>` type (used to define some of our core-module functions), + // we have to account for the case where an expression of type `T` + // is being coerced to a `Ref<T>`. + // + if (auto refType = as<ExplicitRefType>(toType)) + { + // TODO(tfoley): This logic is deeply and fundamentally incorrect. + // It presumes that if an expression of type `T` can coerce to + // type `U` then it can also coerce to a *reference* to `U`. + // That means that because we support, say, implicit coercion of + // an `int` to a `float`, this logic will support implicit coercion + // of an `int` l-value to a `Ref<float>`!!!! + // ConversionCost cost; if (!canCoerce(refType->getValueType(), fromType, fromExpr, &cost)) return false; - if (as<RefType>(toType) && !fromExpr->type.isLeftValue) + + // Depending on whether the result of the coercion would be an l-value + // or not, we may need to restrict the source to be an l-value. + // + // TODO(tfoley): Here we are again hijacking the `QualType` constructor + // to do the direct work. It's still not clear where this logic should + // live. In the longer run, I'm hopeful that we will get rid of + // the explicit `Ref` type entirely (since it was a design mistake to + // begin with), and thus not have to deal with the miserable mess that + // it pushes back on various parts of the compiler. + // + auto qualRefType = QualType(refType); + if (qualRefType.isLeftValue && !fromExpr->type.isLeftValue) + { + // The result type would be an l-value, but the source isn't, + // so there is no way to support the conversion. + // return false; + } + ConversionCost subCost = kConversionCost_GetRef; + if (outCost) + *outCost = subCost; - MakeRefExpr* refExpr = nullptr; if (outToExpr) { - refExpr = m_astBuilder->create<MakeRefExpr>(); + auto refExpr = m_astBuilder->create<MakeRefExpr>(); refExpr->base = fromExpr; - refExpr->type = QualType(refType); - refExpr->type.isLeftValue = false; + refExpr->type = qualRefType; refExpr->checked = true; *outToExpr = refExpr; } - if (outCost) - *outCost = subCost; + return true; } + // TODO(tfoley): I was told that explicit `Ref` types should not + // be seen by most of the compiler because they would be automatically + // eliminated via `maybeOpenRef()` before other code needs to deal + // with them... but that doesn't seem to be the case given how much + // code here in type coercion is having to account for the possibility + // of `Ref` types. - // Allow implicit dereferencing a reference type. - if (auto fromRefType = as<RefTypeBase>(fromType)) + // If we find ourselves in a situation where we need to coerce an + // expression of type `Ref<T>`, we will first unwrap the reference + // to get an expression of type `T` and then coerce *that*. + // + if (auto fromRefType = as<ExplicitRefType>(fromType)) { auto fromValueType = fromRefType->getValueType(); - // If we convert, e.g., `ConstantBuffer<A> to `A`, we will allow - // subsequent conversion of `A` to `B` if such a conversion - // is possible. - // ConversionCost subCost = kConversionCost_None; Expr* openRefExpr = nullptr; @@ -1642,6 +1719,26 @@ bool SemanticsVisitor::_coerce( return false; } + // + // TODO(tfoley): This logic treats the implicit dereferencing + // of a `Ref<T>` as an additional conversion cost, so that + // a function with an explicit `Ref<T>` parameter would end up + // being preferred over one with just a `T`. + // + // Making that distinction and introducing this cost seems to have + // very little benefit, and risks causing developer confusion, + // because for the most part references are invisible to the user + // (intentionally). + // + // We don't want to support explicit `Ref<T>` types in parameter + // positions anyway (people can use either a `ref` parameter or + // an explicit `Ptr<T>`), so the whole thing is moot. + // + // For that matter, we probably should just remove explicit + // `Ref<T>` types from the language, since they were never + // intended to be there in the first place. + // + if (outCost) *outCost = subCost + kConversionCost_ImplicitDereference; return true; @@ -2007,7 +2104,7 @@ bool SemanticsVisitor::tryCoerceLambdaToFuncType( for (auto param : invokeFunc->getParameters()) { auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param); - auto toParamType = toFuncType->getParamType(paramId); + auto toParamType = toFuncType->getParamTypeWithDirectionWrapper(paramId); if (!paramType->equals(toParamType)) { return false; @@ -2189,6 +2286,13 @@ Expr* SemanticsVisitor::coerce( // clobber the type on `fromExpr`, and an invariant here is that coercion // really shouldn't *change* the expression that is passed in, but should // introduce new AST nodes to coerce its value to a different type... + // + // TODO(tfoley): Based on the comment above it seems like my past self + // wrote this code, but looking at it now, I'm unsure why we want to return + // an expression with an error type when we have the `toType` that is + // expected *right there*. It would be good to investigate whether changing + // this to return an expression of the expected type would Just Work. + // return CreateImplicitCastExpr(m_astBuilder->getErrorType(), fromExpr); } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4362f0926..d6b50e999 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1649,7 +1649,7 @@ EnumDecl* isEnumType(Type* type) return nullptr; } -bool isNullableType(Type* type) +bool doesTypeHaveAnUnusedBitPatternThatCanBeUsedForOptionalRepresentation(Type* type) { if (as<PtrTypeBase>(type)) return true; @@ -1659,7 +1659,9 @@ bool isNullableType(Type* type) return true; if (as<OptionalType>(type)) return true; - if (as<RefTypeBase>(type)) + // TODO(tfoley): Somebody put the explicit `Ref<T>` types + // here as a list of a nullable type, and it is + if (as<ExplicitRefType>(type)) return true; if (as<NativeStringType>(type)) return true; @@ -10448,12 +10450,12 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl( decl->errorType.type = funcType->getErrorType(); for (Index i = 0; i < funcType->getParamCount(); i++) { - auto paramType = funcType->getParamType(i); - if (auto dirType = as<ParamDirectionType>(paramType)) - paramType = dirType->getValueType(); + auto paramInfo = funcType->getParamInfo(i); + auto paramType = paramInfo.type; + auto paramDir = paramInfo.direction; + auto param = m_astBuilder->create<ParamDecl>(); param->type.type = paramType; - auto paramDir = funcType->getParamDirection(i); switch (paramDir) { case ParameterDirection::kParameterDirection_InOut: @@ -13051,7 +13053,7 @@ void checkDerivativeAttributeImpl( Diagnostics::customDerivativeSignatureMismatchAtPosition, ii, qualTypeToString(argList[ii]->type), - funcType->getParamType(ii)->toString()); + funcType->getParamTypeWithDirectionWrapper(ii)->toString()); } } // The `imaginaryArguments` list does not include the `this` parameter. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 8994eb783..235b57ca6 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -223,12 +223,26 @@ Expr* SemanticsVisitor::maybeOpenRef(Expr* expr) { auto exprType = expr->type.type; - if (auto refType = as<RefTypeBase>(exprType)) + if (auto refType = as<ExplicitRefType>(exprType)) { auto openRef = m_astBuilder->create<OpenRefExpr>(); openRef->innerExpr = expr; - openRef->type.isLeftValue = (as<RefType>(exprType) != nullptr); + + // TODO(tfoley): The `QualType` constructor has its own + // logic to determine the value category (e.g., whether + // or not something is an l-value) when it is passed + // a `Ref` type. It is unclear whether both this code + // *and* that code are required, or if we can consolidate + // the two. + // + // Note that here we change the actual `Type*` stored in + // the `QualType` to be the underlying value type of the + // reference, whereas the `QualType` constructor does not + // perform such unwrapping. + // + openRef->type = QualType(refType); openRef->type.type = refType->getValueType(); + openRef->checked = true; openRef->loc = expr->loc; return openRef; @@ -538,10 +552,26 @@ Expr* SemanticsVisitor::constructDerefExpr(Expr* base, QualType elementType, Sou derefExpr->type = QualType(elementType); derefExpr->checked = true; - if (as<PtrType>(base->type) || as<RefType>(base->type)) + if (as<PtrType>(base->type)) { + // TODO(tfoley): It is not clear why this is being unconditionally + // set to `true` when the `Ptr` types in the core module has an + // `AccessQualifier` parameter that can be used to form a read-only pointer. + // derefExpr->type.isLeftValue = true; } + else if (as<ExplicitRefType>(base->type)) + { + // TODO(tfoley): The code here is exploiting the ability of the + // `QualType` constructor to compute the correct value category + // for a reference type, so that we don't have to repeat that logic + // here. That might not be the right place for that logic to live, + // however, and so the code here might need updating sooner or + // later. + // + bool baseIsLVal = QualType(base->type.type).isLeftValue; + derefExpr->type.isLeftValue = baseIsLVal; + } else if (isImmutableBufferType(base->type)) { derefExpr->type.isLeftValue = false; @@ -2925,7 +2955,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) Index paramCount = funcType->getParamCount(); for (Index pp = 0; pp < paramCount; ++pp) { - auto paramType = funcType->getParamType(pp); + auto paramType = funcType->getParamTypeWithDirectionWrapper(pp); Expr* argExpr = nullptr; ParamDecl* paramDecl = nullptr; if (pp < expr->arguments.getCount()) @@ -2936,7 +2966,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) } compareMemoryQualifierOfParamToArgument(paramDecl, argExpr); - if (as<OutTypeBase>(paramType) || as<RefType>(paramType)) + if (as<OutTypeBase>(paramType) || as<RefParamType>(paramType)) { // `out`, `inout`, and `ref` parameters currently require // an *exact* match on the type of the argument. @@ -3047,7 +3077,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) const DiagnosticInfo* diagnostic = nullptr; // Try and determine reason for failure - if (as<RefType>(paramType)) + if (as<RefParamType>(paramType)) { // Ref types are not allowed to use this mechanism because // it breaks atomics @@ -3537,21 +3567,66 @@ Expr* SemanticsExprVisitor::maybeRegisterLambdaCapture(Expr* exprIn) return resultMemberExpr; } -Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) +Type* SemanticsVisitor::_toDifferentialParamType(Type* primalParamType) { - // Check for type modifiers like 'out' and 'inout'. We need to differentiate the - // nested type. + // This function is invoked on parameter types that could + // still be wrapped to represent a parameter-passing mode + // like `ref`, `out`, etc. // - if (auto primalOutType = as<OutType>(primalType)) - { - return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType())); - } - else if (auto primalInOutType = as<InOutType>(primalType)) + // We need to intercept these cases here, and ensure that + // the wrapper is not exposed to other parts of the front-end + // code, because they only exist to encode the parameter-passing + // mode, and are not a proper part of the Slang type system + // (at least not at this time). + // + if (auto primalParamWrapperType = as<ParamDirectionType>(primalParamType)) { - return m_astBuilder->getInOutType( - _toDifferentialParamType(primalInOutType->getValueType())); + // Some parameter-passing modes do not naturally lend themselves + // to being differentiated - most notably, `ref` parameters. + // We will detect those cases here, and handle them as a parameter + // of a non-differentiable type would be handled. + // + // TODO(tfoley): With the introduction of `IDifferentiablePtrType`, + // it is possible that something like a `ref` parameter could also + // support autodiff, but it is not clear what a correct + // one-size-fits-all behavior should be in that case. + // + if (as<RefParamType>(primalParamType)) + return primalParamWrapperType; + + // Given a primal type that is a wrapper like `Out<T>`, we can + // extract the underlying primal value type `T`, and determine + // what the differential type value type corresponding to `T` + // should be. + // + auto primalValueType = primalParamWrapperType->getValueType(); + auto diffValueType = _toDifferentialParamType(primalValueType); + + // Once we have created the appropriate differential value type, + // we will form the differential parameter type by wrapping + // the differential value type in the same wrapper that had + // been used for the primal type. + // + if (as<OutType>(primalParamWrapperType)) + { + return m_astBuilder->getOutType(diffValueType); + } + else if (as<InOutType>(primalParamWrapperType)) + { + return m_astBuilder->getInOutType(diffValueType); + } + else if (as<ConstRefParamType>(primalParamWrapperType)) + { + return m_astBuilder->getConstRefParamType(diffValueType); + } + else + { + SLANG_UNEXPECTED("unhandled parameter-passing mode"); + UNREACHABLE_RETURN(diffValueType); + } } - return getDifferentialPairType(primalType); + + return getDifferentialPairType(primalParamType); } Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) @@ -3632,7 +3707,8 @@ Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) for (Index i = 0; i < originalType->getParamCount(); i++) { - if (auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) + if (auto jvpParamType = + _toDifferentialParamType(originalType->getParamTypeWithDirectionWrapper(i))) paramTypes.add(jvpParamType); } FuncType* jvpType = @@ -3658,7 +3734,9 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) for (Index i = 0; i < originalType->getParamCount(); i++) { - if (auto outType = as<OutType>(originalType->getParamType(i))) + auto originalParamType = originalType->getParamTypeWithDirectionWrapper(i); + + if (auto outType = as<OutType>(originalParamType)) { auto diffElementType = tryGetDifferentialType(m_astBuilder, outType->getValueType()); if (diffElementType) @@ -3670,7 +3748,7 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) continue; } } - else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + else if (auto derivType = _toDifferentialParamType(originalParamType)) { if (as<DifferentialPairType>(derivType)) { diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index d82bf4427..dd5c816b1 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -3239,7 +3239,7 @@ bool isImmutableBufferType(Type* type); // Check if `type` is nullable. An `Optional<T>` will occupy the same space as `T`, if `T` // is nullable. -bool isNullableType(Type* type); +bool doesTypeHaveAnUnusedBitPatternThatCanBeUsedForOptionalRepresentation(Type* type); EnumDecl* isEnumType(Type* type); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index e0dbc7e08..2ad31c9d1 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -643,12 +643,35 @@ static QualType getParamQualType(ASTBuilder* astBuilder, DeclRef<ParamDecl> para static QualType getParamQualType(Type* paramType) { + // TODO(tfoley): This function probably shouldn't exist, and instead + // the accessors for the parameters of a `FuncType` should + // directly return a `QualType` for each parameter rather than + // a plain `Type` that potentially includes a wrapping + // `ParamDirectionType`. + // + // In addition, the determination of what value category a reference + // to a parameter should be (and thus what the `QualType` sould be) + // should be driven by computing the `ParameterDirection` first, + // and then using the direction to determine the value category + // (so as to isolate the code that needs to care about the wrapper + // types to just the computation of the dirction). + // + // Note the large amount of duplication between this function and + // the other `getParamQualType()` above. + // + bool isLVal = false; + Type* valueType = paramType; if (auto paramDirType = as<ParamDirectionType>(paramType)) { - if (as<OutTypeBase>(paramDirType) || as<RefType>(paramDirType)) - return QualType(paramDirType->getValueType(), true); + valueType = paramDirType->getValueType(); + if (as<InOutParamType>(paramDirType)) + isLVal = true; + if (as<OutParamType>(paramDirType)) + isLVal = true; + if (as<RefParamType>(paramDirType)) + isLVal = true; } - return paramType; + return QualType(valueType, isLVal); } bool SemanticsVisitor::TryCheckOverloadCandidateTypes( @@ -673,7 +696,7 @@ bool SemanticsVisitor::TryCheckOverloadCandidateTypes( Count paramCount = funcType->getParamCount(); for (Index i = 0; i < paramCount; ++i) { - auto paramType = getParamQualType(funcType->getParamType(i)); + auto paramType = getParamQualType(funcType->getParamTypeWithDirectionWrapper(i)); paramTypes.add(paramType); } } @@ -2666,7 +2689,8 @@ void SemanticsVisitor::AddHigherOrderOverloadCandidates( List<QualType> paramTypes; for (Index ii = 0; ii < diffFuncType->getParamCount(); ii++) - paramTypes.add(getParamQualType(diffFuncType->getParamType(ii))); + paramTypes.add( + getParamQualType(diffFuncType->getParamTypeWithDirectionWrapper(ii))); // Try to infer generic arguments, based on the updated context. OverloadResolveContext subContext = context; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 7c9111629..29134131c 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -778,13 +778,13 @@ Type* getParamTypeWithDirectionWrapper(ASTBuilder* astBuilder, DeclRef<VarDeclBa case kParameterDirection_In: return result; case kParameterDirection_ConstRef: - return astBuilder->getConstRefType(result); + return astBuilder->getConstRefParamType(result); case kParameterDirection_Out: return astBuilder->getOutType(result); case kParameterDirection_InOut: return astBuilder->getInOutType(result); case kParameterDirection_Ref: - return astBuilder->getRefType(result); + return astBuilder->getRefParamType(result); default: return result; } diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 82f9596e6..e3a93bd86 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -27,7 +27,7 @@ Type* getPointedToTypeIfCanImplicitDeref(Type* type) { return ptrType->getValueType(); } - else if (auto refType = as<RefType>(type)) + else if (auto refType = as<ExplicitRefType>(type)) { return refType->getValueType(); } diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.cpp b/source/slang/slang-ir-fix-entrypoint-callsite.cpp index 7390f3a7f..a0ab07928 100644 --- a/source/slang/slang-ir-fix-entrypoint-callsite.cpp +++ b/source/slang/slang-ir-fix-entrypoint-callsite.cpp @@ -63,6 +63,11 @@ void fixEntryPointCallsites(IRFunc* entryPoint) // and the caller is passing a value, we need to wrap the value in a temporary var // and pass the temporary var. // + // TODO(tfoley): Wait, what? The situation this code is trying to fix should + // never be allowed to occur in the first place. This code shouldn't be + // trying to defend against the bad input; instead we should be *fixing* + // the source of the problem. + // auto funcType = as<IRFuncType>(callee->getDataType()); SLANG_ASSERT(funcType); IRBuilder builder(call); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 792848ce6..3cf5d7803 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2042,7 +2042,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower List<IRType*> paramTypes; for (Index pp = 0; pp < paramCount; ++pp) { - paramTypes.add(lowerType(context, type->getParamType(pp))); + paramTypes.add(lowerType(context, type->getParamTypeWithDirectionWrapper(pp))); } if (type->getErrorType()->equals(context->astBuilder->getBottomType())) { @@ -2820,12 +2820,14 @@ void addArg( // from the arg. paramType = lowerType(context, argType); } +#if 0 if (auto refType = as<IRConstRefType>(paramType)) { paramType = refType->getValueType(); argVal = LoweredValInfo::simple( context->irBuilder->emitLoad(getSimpleVal(context, argPtr))); } +#endif LoweredValInfo tempVar = createVar(context, paramType); @@ -3814,13 +3816,13 @@ struct ExprLoweringContext for (Index i = 0; i < argCount; ++i) { - IRType* paramType = lowerType(context, funcType->getParamType(i)); - ParameterDirection paramDirection = funcType->getParamDirection(i); + auto paramInfo = funcType->getParamInfo(i); + IRType* paramType = lowerType(context, paramInfo.type); addDirectCallArgs( expr, i, paramType, - paramDirection, + paramInfo.direction, DeclRef<ParamDecl>(), ioArgs, ioFixups); @@ -4736,16 +4738,55 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> if (loweredBase.flavor != LoweredValInfo::Flavor::Ptr) { - SLANG_ASSERT(as<ConstRefType>(expr->type)); - // If the base isn't a pointer, then we are trying to form - // a const ref to a temporary value. - // To do so we must copy it into a variable. + // If the base expression is not one that (trivially) + // lower to a pointer, then we have a bit of a problem, + // because the semantics of forming a reference are + // that we should refer to the memory location of + // the operand itself. + // + // For now, we are hacking this case by supporting + // formation of a *read-only* reference when the base + // expression is an r-value, by first copying the base + // expression into a temporary. + // + // Note that this approach is semantically incorrect, + // and a fix should be made further up the stack to + // rule out whatever is happening here. + // + // TODO(tfoley): Investigate why this case is arising + // at all, and/or eliminate the explicit `Ref` type + // entirely, so we don't have to deal with it. + + + // We start by asserting that the reference type we + // are being asked to form is read-only. + // + SLANG_ASSERT(as<ExplicitRefType>(expr->type) && !QualType(expr->type).isLeftValue); + + // Now we perpetrate our hackery, by forming a simple value + // for the operand in an SSA register and copying it into + // a temporary. + // + // TODO(tfoley): This logic might be better expressed by + // forming a `LoweredValInfo` for the temporary and then + // using the `assign()` operation to write the base into it, + // since that operation might produce simpler code than + // we get by using `getSimpleVal` here. + // auto baseVal = getSimpleVal(context, loweredBase); auto tempVar = context->irBuilder->emitVar(baseVal->getFullType()); context->irBuilder->emitStore(tempVar, baseVal); loweredBase.val = tempVar; } + // Note that the `flavor` of the lowered value that we return + // is always `Simple`, because at the level of the IR a value + // of type `Ref` is just a pointer. + // + // In the case where the hack above was used to introduce a + // temporary, the pointer value is the address of the temporary + // variable itself. + // loweredBase.flavor = LoweredValInfo::Flavor::Simple; return loweredBase; } diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 311725e08..d96b5591b 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -253,7 +253,7 @@ void emitType(ManglingContext* context, Type* type) auto n = funcType->getParamCount(); emit(context, n); for (Index i = 0; i < n; ++i) - emitType(context, funcType->getParamType(i)); + emitType(context, funcType->getParamTypeWithDirectionWrapper(i)); emitType(context, funcType->getResultType()); emitType(context, funcType->getErrorType()); } diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index d9eb884f0..b991a0caf 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -925,13 +925,18 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR { paramType = astBuilder->getErrorType(); } + + // TODO(tfoley): This code should first compute the appropriate + // parameter-passing mode ("direction") for the `paramDecl` and + // then use that mode to decide which wrapper type to use. + // if (paramDecl->findModifier<RefModifier>()) { - paramType = astBuilder->getRefType(paramType); + paramType = astBuilder->getRefParamType(paramType); } else if (paramDecl->findModifier<ConstRefModifier>()) { - paramType = astBuilder->getConstRefType(paramType); + paramType = astBuilder->getConstRefParamType(paramType); } else if (paramDecl->findModifier<OutModifier>()) { diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 519b3ab06..0de6348bf 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -5088,9 +5088,26 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type } else if (auto optionalType = as<OptionalType>(type)) { - // OptionalType should be laid out the same way as Tuple<T, bool>. - if (isNullableType(optionalType->getValueType())) + // Sometimes a type `T` has an unused bit pattern that + // can be used to represent the null/absent optional value, + // and for such types the size of an `Optional<T>` can be + // the same as a `T`, by making use of that unused pattern. + // + if (doesTypeHaveAnUnusedBitPatternThatCanBeUsedForOptionalRepresentation( + optionalType->getValueType())) return _createTypeLayout(context, optionalType->getValueType()); + + // For all other types, an `Optional<T>` is laid out more-or-less + // as tuple of a `T` and a `bool`. + // + // TODO(tfoley): This code implements the `(T,bool)` ordering, + // which provides more easy opportunities to generate compact + // layouts by using "tail padding" than the `(bool, T)` ordering. + // However the "natural layout" implementation does not match + // what is being done here (it uses the `(bool, T)` ordering). + // The discrepancy should probably be fixed, but doing so would + // technically be a breaking change. + // Array<Type*, 2> types = makeArray(optionalType->getValueType(), context.astBuilder->getBoolType()); auto tupleType = context.astBuilder->getTupleType(types.getView()); |
