diff options
25 files changed, 683 insertions, 166 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 5c4a35be2..43658563e 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1284,6 +1284,7 @@ enum AddressSpace : uint64_t { Device = $((uint64_t)AddressSpace::UserPointer), GroupShared = $((uint64_t)AddressSpace::GroupShared), + VaryingInput = $((uint64_t)AddressSpace::Input), }; /// @category misc_types @@ -1328,29 +1329,49 @@ struct Ptr< __intrinsic_op($(kIROp_CastIntToPtr)) __init(int64_t val); - // By default, getter is not an L value + /// Subscript a pointer to get a reference to a value. + /// + /// The pointer must reference a memory location where + /// N >= 0 values of type `T` are stored sequentially, + /// in a layout consistent with an array `T[N]`; + /// otherwise, this operation has undefined behavior. + /// + /// The `index` parameter must satisfy 0 <= `index` <= N; + /// otherwise, this operation has undefined behavior. + /// __generic<TInt : __BuiltinIntegerType> __subscript(TInt index) -> T { - __intrinsic_op($(kIROp_GetOffsetPtr)) - [nonmutating] - ref; - } -}; + // + // TODO(tfoley): The `Ptr` type's subscript operation + // currently provides a `ref` accessor (which returns a + // reference that can be used for reading or writing), + // independent of the `Access` that is specified by the + // generic arguments of a particular pointer type. In + // practice, this means that subscripting a read-only + // pointer yields a readable *and* writable reference. + // + // Previous versions of the code module attempted to + // address the problem by providing a `get` accessor + // on all platforms, and then use an `extension` to + // introduce a `ref` accessor only to pointers that + // are known to be mutable. That approach does not work, + // however, because it is necessary that subscripting + // a pointer yields a reference (l-value), whether or not + // the reference is read-only. + // + // A different solution to the problem is needed, whether + // by splitting up read-only and read-write pointers as + // distinct types (as many languages do), or by allowing + // the subscript operation to explicitly return a reference + // with the same access as the pointer itself. + // -extension<T, AddressSpace addrSpace> Ptr<T, Access::ReadWrite, addrSpace> -{ - // We have a `ref` accessor if we are ReadWrite. This means only `ReadWrite` - // can be used as an L-value. - __generic<TInt : __BuiltinIntegerType> - __subscript(TInt index) -> Ref<T> - { - // If a 'Ptr[index]' is referred to by a '__ref', call 'kIROp_GetOffsetPtr(index)' __intrinsic_op($(kIROp_GetOffsetPtr)) [nonmutating] ref; } -} +}; //@hidden: __intrinsic_op($(kIROp_AlignedAttr)) @@ -1505,28 +1526,33 @@ extension uintptr_t : IRangedValue } //@hidden: -__generic<T> -__magic_type(OutType) + +__magic_type(ExplicitRefType) +__intrinsic_type($(kIROp_RefType)) +struct Ref< + T, + Access access = Access::ReadWrite, + AddressSpace addrSpace = AddressSpace::Device> +{}; + +__magic_type(OutParamType) __intrinsic_type($(kIROp_OutType)) -struct Out +struct OutParam<T> {}; -__generic<T> -__magic_type(InOutType) +__magic_type(InOutParamType) __intrinsic_type($(kIROp_InOutType)) -struct InOut +struct InOutParam<T> {}; -__generic<T> -__magic_type(RefType) +__magic_type(RefParamType) __intrinsic_type($(kIROp_RefType)) -struct Ref +struct RefParam<T> {}; -__generic<T> -__magic_type(ConstRefType) +__magic_type(ConstRefParamType) __intrinsic_type($(kIROp_ConstRefType)) -struct ConstRef +struct ConstRefParam<T> {}; // __Addr<T> is AddressSpace::Generic since Slang will specalize & validate the address-space @@ -2669,15 +2695,9 @@ for (auto op : intrinsicUnaryOps) }}}} // Only ReadWrite is an L-value. -__generic<T, AddressSpace addrSpace> -__intrinsic_op(0) -__prefix Ref<T> operator*(Ptr<T, Access::ReadWrite, addrSpace> value); - -// Unknown access qualifier or Access::Read access qualifier is a promise -// that the pointer is not going to be used as an L-value. -__generic<$(ptrTypeParameterList)> +__generic<T, Access access, AddressSpace addrSpace> __intrinsic_op(0) -__prefix ConstRef<T> operator*($(fullPtrType) value); +__prefix Ref<T, access, addrSpace> operator*(Ptr<T, access, addrSpace> value); // TODO: [require(cpu)]. This cannot be done yet since this change breaks slangpy __generic<T> diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 96adabbde..ff30a921c 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -10354,10 +10354,10 @@ matrix<T, N, M> fwidth_fine(matrix<T, N, M> x) } __intrinsic_op($(kIROp_ResolveVaryingInputRef)) -Ref<T> __ResolveVaryingInputRef<T>(__constref T attribute); +Ref<T, Access.Read, AddressSpace.VaryingInput> __ResolveVaryingInputRef<T>(__constref T attribute); __intrinsic_op($(kIROp_GetPerVertexInputArray)) -Ref<Array<T, 3>> __GetPerVertexInputArray<T>(__constref T attribute); +Ref<Array<T, 3>, Access.Read, AddressSpace.VaryingInput> __GetPerVertexInputArray<T>(__constref T attribute); T __GetAttributeAtVertex<T>(__constref T attribute, uint vertexIndex) { diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index a88db3155..f7304f308 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -512,22 +512,27 @@ Type* ASTBuilder::getScalarLayoutType() // Construct the type `Out<valueType>` OutType* ASTBuilder::getOutType(Type* valueType) { - return dynamicCast<OutType>(getPtrType(valueType, "OutType")); + return dynamicCast<OutType>(getPtrType(valueType, "OutParamType")); } InOutType* ASTBuilder::getInOutType(Type* valueType) { - return dynamicCast<InOutType>(getPtrType(valueType, "InOutType")); + return dynamicCast<InOutType>(getPtrType(valueType, "InOutParamType")); } -RefType* ASTBuilder::getRefType(Type* valueType) +RefParamType* ASTBuilder::getRefParamType(Type* valueType) { - return dynamicCast<RefType>(getPtrType(valueType, "RefType")); + return dynamicCast<RefParamType>(getPtrType(valueType, "RefParamType")); } -ConstRefType* ASTBuilder::getConstRefType(Type* valueType) +ConstRefParamType* ASTBuilder::getConstRefParamType(Type* valueType) { - return dynamicCast<ConstRefType>(getPtrType(valueType, "ConstRefType")); + return dynamicCast<ConstRefParamType>(getPtrType(valueType, "ConstRefParamType")); +} + +ExplicitRefType* ASTBuilder::getExplicitRefType(Type* valueType) +{ + return dynamicCast<ExplicitRefType>(getPtrType(valueType, "ExplicitRefType")); } OptionalType* ASTBuilder::getOptionalType(Type* valueType) diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 72a3db4ab..9386180c8 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -536,17 +536,23 @@ public: PtrType* getPtrType(Type* valueType, AccessQualifier accessQualifier, AddressSpace addrSpace); PtrType* getPtrType(Type* valueType, Val* accessQualifier, Val* addrSpace); - // Construct the type `Out<valueType>` - OutType* getOutType(Type* valueType); + // Construct the type `OutParam<valueType>` + OutParamType* getOutType(Type* valueType); - // Construct the type `InOut<valueType>` - InOutType* getInOutType(Type* valueType); + // Construct the type `InOutParam<valueType>` + InOutParamType* getInOutType(Type* valueType); + + // Construct the type `RefParam<valueType>` + RefParamType* getRefParamType(Type* valueType); + + // Construct the type `ConstRefParam<valueType>` + ConstRefParamType* getConstRefParamType(Type* valueType); // Construct the type `Ref<valueType>` - RefType* getRefType(Type* valueType); + ExplicitRefType* getExplicitRefType(Type* valueType); - // Construct the type `ConstRef<valueType>` - ConstRefType* getConstRefType(Type* valueType); + // Construct the type `Ref<valueType, .Read>` + ExplicitRefType* getExplicitConstRefType(Type* valueType); // Construct the type `Optional<valueType>` OptionalType* getOptionalType(Type* valueType); diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp index 7643e0433..27f4d43eb 100644 --- a/source/slang/slang-ast-natural-layout.cpp +++ b/source/slang/slang-ast-natural-layout.cpp @@ -174,8 +174,30 @@ NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) } else if (auto optionalType = as<OptionalType>(type)) { - if (isNullableType(optionalType->getValueType())) + // Sometimes a type `T` has an unused bit pattern that + // can be used to represent the null/absent optional value, + // and for such types the size of an `Optional<T>` can be + // the same as a `T`, by making use of that unused pattern. + // + if (doesTypeHaveAnUnusedBitPatternThatCanBeUsedForOptionalRepresentation( + optionalType->getValueType())) return calcSize(optionalType->getValueType()); + + // For all other types, an `Optional<T>` is laid out more-or-less + // as a tuple of a `bool` and a `T`. + // + // TODO(tfoley): This appears to be the exact *opposite* of how + // we should be laying out optionals if we want to be at all + // efficient about space. For various targets and layout modes + // (with natural layout currently being one of them), a type + // can have "tail padding," when its size is not a multiple of + // its alignment. In such cases laying things out as `(T, bool)` + // can both end up takign advantage of the tail padding of `T` + // when present *or* for types `T` that don't include tail + // padding in their layout, but have an alignment N > 1 + // the `(T, bool)` order will then *create* N-1 bytes of tail + // padding (that can possibly be exploited elsewhere). + // NaturalSize size = NaturalSize::makeEmpty(); size.append(calcSize(m_astBuilder->getBoolType())); size.append(calcSize(optionalType->getValueType())); diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index 1f5f4f8b2..9d9fdb7da 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -11,13 +11,26 @@ namespace Slang QualType::QualType(Type* type) : type(type), isLeftValue(false) { - if (as<RefType>(type)) + if (auto refType = as<ExplicitRefType>(type)) { - isLeftValue = true; - } - else if (as<ConstRefType>(type)) - { - isLeftValue = false; + if (auto optAccessQualifier = refType->tryGetAccessQualifierValue()) + { + auto accessQualifier = *optAccessQualifier; + switch (accessQualifier) + { + case AccessQualifier::ReadWrite: + isLeftValue = true; + break; + + case AccessQualifier::Read: + isLeftValue = false; + break; + + default: + SLANG_UNEXPECTED("unhandled access qualifier"); + break; + } + } } } diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 53f6626d7..1af81a88f 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -450,15 +450,19 @@ Val* PtrTypeBase::getAddressSpace() return _getGenericTypeArg(this, 2); } -AccessQualifier tryGetAccessQualifierValue(Val* val) +std::optional<AccessQualifier> tryGetAccessQualifierValue(Val* val) { - AccessQualifier accessQualifier = AccessQualifier::ReadWrite; - if (auto cintVal = as<ConstantIntVal>(val)) { - accessQualifier = (AccessQualifier)(cintVal->getValue()); + return AccessQualifier(cintVal->getValue()); } - return accessQualifier; + return std::optional<AccessQualifier>(); +} + +std::optional<AccessQualifier> PtrTypeBase::tryGetAccessQualifierValue() +{ + auto accessQualifierArg = this->getAccessQualifier(); + return Slang::tryGetAccessQualifierValue(accessQualifierArg); } AddressSpace tryGetAddressSpaceValue(Val* addrSpaceVal) @@ -517,24 +521,44 @@ void maybePrintAccessQualifierOperand(StringBuilder& out, AccessQualifier access void PtrType::_toTextOverride(StringBuilder& out) { - auto accessQualifier = tryGetAccessQualifierValue(getAccessQualifier()); auto addrSpace = tryGetAddressSpaceValue(getAddressSpace()); out << toSlice("Ptr<") << getValueType(); - maybePrintAccessQualifierOperand(out, accessQualifier); + if (auto optionalAccessQualifier = tryGetAccessQualifierValue()) + maybePrintAccessQualifierOperand(out, *optionalAccessQualifier); maybePrintAddrSpaceOperand(out, addrSpace); out << toSlice(">"); } -void RefType::_toTextOverride(StringBuilder& out) +void ExplicitRefType::_toTextOverride(StringBuilder& out) { - auto accessQualifier = tryGetAccessQualifierValue(getAccessQualifier()); auto addrSpace = tryGetAddressSpaceValue(getAddressSpace()); out << toSlice("Ref<") << getValueType(); - maybePrintAccessQualifierOperand(out, accessQualifier); + if (auto optionalAccessQualifier = tryGetAccessQualifierValue()) + maybePrintAccessQualifierOperand(out, *optionalAccessQualifier); maybePrintAddrSpaceOperand(out, addrSpace); out << toSlice(">"); } +void OutParamType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("out ") << getValueType(); +} + +void InOutParamType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("inout ") << getValueType(); +} + +void RefParamType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("ref ") << getValueType(); +} + +void ConstRefParamType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("borrow ") << getValueType(); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -556,14 +580,13 @@ Type* NamedExpressionType::_createCanonicalTypeOverride() // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -ParameterDirection FuncType::getParamDirection(Index index) +ParameterDirection getParamPassingModeFromPossiblyWrappedParamType(Type* paramType) { - auto paramType = getParamType(index); - if (as<RefType>(paramType)) + if (as<RefParamType>(paramType)) { return kParameterDirection_Ref; } - else if (as<ConstRefType>(paramType)) + else if (as<ConstRefParamType>(paramType)) { return kParameterDirection_ConstRef; } @@ -581,6 +604,21 @@ ParameterDirection FuncType::getParamDirection(Index index) } } +ParameterDirection FuncType::getParamDirection(Index index) +{ + auto paramType = getParamTypeWithDirectionWrapper(index); + return getParamPassingModeFromPossiblyWrappedParamType(paramType); +} + +Type* FuncType::getParamValueType(Index index) +{ + auto paramType = getParamTypeWithDirectionWrapper(index); + if (auto wrappedParamType = as<ParamDirectionType>(paramType)) + return wrappedParamType->getValueType(); + return paramType; +} + + void FuncType::_toTextOverride(StringBuilder& out) { Index paramCount = getParamCount(); @@ -591,7 +629,7 @@ void FuncType::_toTextOverride(StringBuilder& out) { out << toSlice(", "); } - out << getParamType(pp); + out << getParamTypeWithDirectionWrapper(pp); } out << ") -> " << getResultType(); @@ -615,7 +653,8 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s List<Type*> substParamTypes; for (Index pp = 0; pp < getParamCount(); pp++) { - auto substParamType = as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff)); + auto substParamType = as<Type>( + getParamTypeWithDirectionWrapper(pp)->substituteImpl(astBuilder, subst, &diff)); if (auto typePack = as<ConcreteTypePack>(substParamType)) { // Unwrap the ConcreteTypePack and add each element as a parameter @@ -650,7 +689,7 @@ Type* FuncType::_createCanonicalTypeOverride() List<Type*> canParamTypes; for (Index pp = 0; pp < getParamCount(); pp++) { - canParamTypes.add(getParamType(pp)->getCanonicalType()); + canParamTypes.add(getParamTypeWithDirectionWrapper(pp)->getCanonicalType()); } FuncType* canType = getCurrentASTBuilder()->getFuncType( diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 4994328b2..26d86fc7f 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -687,6 +687,8 @@ class PtrTypeBase : public BuiltinType Type* getValueType(); Val* getAccessQualifier(); Val* getAddressSpace(); + + std::optional<AccessQualifier> tryGetAccessQualifierValue(); }; FIDDLE() @@ -709,7 +711,14 @@ class PtrType : public PtrTypeBase void _toTextOverride(StringBuilder& out); }; -/// A pointer-like type used to represent a parameter "direction" +/// A pointer-like type used to represent a parameter-passing mode. +/// +/// Historically the codebase has referredd to different parameter-passing +/// modes as parameter "directions," because they initially included +/// only `in`, `out`, and `inout`. The name is confusing when applied +/// to things like `ref` parameters, but we haven't had time to rename +/// everything yet. +/// FIDDLE() class ParamDirectionType : public PtrTypeBase { @@ -720,44 +729,68 @@ class ParamDirectionType : public PtrTypeBase // logical pointer that is passed for an `out` // or `in out` parameter FIDDLE(abstract) -class OutTypeBase : public ParamDirectionType +class OutParamTypeBase : public ParamDirectionType { FIDDLE(...) }; +using OutTypeBase = OutParamTypeBase; // The type for an `out` parameter, e.g., `out T` FIDDLE() -class OutType : public OutTypeBase +class OutParamType : public OutParamTypeBase { FIDDLE(...) + void _toTextOverride(StringBuilder& out); }; +using OutType = OutParamType; // The type for an `in out` parameter, e.g., `in out T` FIDDLE() -class InOutType : public OutTypeBase +class InOutParamType : public OutParamTypeBase { FIDDLE(...) + void _toTextOverride(StringBuilder& out); }; +using InOutType = InOutParamType; -FIDDLE(abstract) -class RefTypeBase : public ParamDirectionType +// The type for an `ref` parameter, e.g., `ref T` +FIDDLE() +class RefParamType : public ParamDirectionType { FIDDLE(...) + void _toTextOverride(StringBuilder& out); }; -// The type for an `ref` parameter, e.g., `ref T` +/// The type for a `constref` parameter, e.g., `constref T` +/// +/// Note that, despite the modifier currently used to represent +/// this case in code, this is *not* comparable to the `ref` +/// parameter-passing mode, and is instead an input-only +/// equivalent of `inout`. +/// FIDDLE() -class RefType : public RefTypeBase +class ConstRefParamType : public ParamDirectionType { FIDDLE(...) void _toTextOverride(StringBuilder& out); }; -// The type for an `constref` parameter, e.g., `constref T` +/// A reference type that is explicitly named somewhere in code (`Ref<T>`). +/// +/// The explicit reference types are distinct from the +/// parameter-passing mode wrapper types like `RefParamType`. +/// An explicit reference type is a type that code written in +/// Slang is allowed to name (e.g., by having a function that +/// returns a `Ref<T>`), even if those uses may only occur +/// in the core module. In constrast, the parameter-passing +/// mode wrapper types should only ever be used as part of +/// the encoding of a `FuncType`. +/// FIDDLE() -class ConstRefType : public RefTypeBase +class ExplicitRefType : public PtrTypeBase { FIDDLE(...) + void _toTextOverride(StringBuilder& out); }; FIDDLE() @@ -812,12 +845,92 @@ class FuncType : public Type OperandView<Type> getParamTypes() { return OperandView<Type>(this, 0, getOperandCount() - 2); } Index getParamCount() { return m_operands.getCount() - 2; } - Type* getParamType(Index index) { return as<Type>(getOperand(index)); } - Type* getResultType() { return as<Type>(getOperand(m_operands.getCount() - 2)); } - Type* getErrorType() { return as<Type>(getOperand(m_operands.getCount() - 1)); } + /// Get the type of one of the function's parameters, by index. + /// + /// The type returned by this function may include a wrapper + /// type around what the user-perceived type of the parameter + /// is. For example, if a parameter is declared as `out int a` + /// then this function would return a type coresponding to + /// `OutParam<int>`, using the hidden `OutParam<T>` type defined + /// in the core module. + /// + /// Any code that calls this function should be conscious of + /// the possibility of encountering these wrappers, and handle + /// them accordingly. + /// + Type* getParamTypeWithDirectionWrapper(Index index) { return as<Type>(getOperand(index)); } + + /// Get the type of one of the function's parameters, by index. + /// + /// The type returned by this funciton is the user-perceived + /// type of the parameter, and does not include any wrappers + /// that are introduced to indicate the parameter-passing mode. + /// For example, a parameter declared as `out int a` will simply + /// return the `int` type, the same as would be returned for + /// a parameter simply declared as `int a`. + /// + /// Any code that calls this function should be conscious of + /// the possibility that the type returned may not fully + /// describe the contract for the given parameter, and should + /// make sure to consult `getParamDirection` as well, to get + /// a complete picture. + /// + Type* getParamValueType(Index index); + + /// Get the parameter-passing mode of one of the function's parameters, by index. + /// ParameterDirection getParamDirection(Index index); + /// Combined information on the type and parameter-passing mode of a parameter. + /// + struct ParamInfo + { + /// The parameter-passing mode used for the parameter. + ParameterDirection direction = kParameterDirection_In; + + /// The user-perceived type of the parameter. + Type* type = nullptr; + }; + + /// Get combined information on the type and parameter-passing mode of a parameter. + /// + ParamInfo getParamInfo(Index index) + { + ParamInfo info; + info.direction = getParamDirection(index); + info.type = getParamValueType(index); + return info; + } + + /// Get the result type of this function. + /// + /// This is the type that a call to the function evaluates to if + /// the function returns successfully. + /// + /// A function that conceptually returns no value will have the `Unit` + /// type as its result type. + /// + /// A type that can never return will have the bottom type `Never` + /// as its result type. + /// + Type* getResultType() { return as<Type>(getOperand(m_operands.getCount() - 2)); } + + /// Get the type of errors (if any) that this function can fail with. + /// + /// Evaluation of a call to a function with this `FuncType` may fail + /// with an error of the corresponding error type. + /// + /// A function that cannot fail with an error will have the bottom + /// type `Never` as its error type. + /// + /// Note that a function that "never fails" at the type system level + /// may still fail in various ways that are perceivable to the user. + /// The error type of a function only refers to failure modes that + /// are being explicitly modeled using the Slang type system. + /// + Type* getErrorType() { return as<Type>(getOperand(m_operands.getCount() - 1)); } + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Type* _createCanonicalTypeOverride(); diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 9355834b5..0a96eaaef 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -1110,8 +1110,8 @@ bool SemanticsVisitor::TryUnifyTypesByStructuralMatch( if (!TryUnifyTypes( constraints, unifyCtx, - fstFunType->getParamType(i), - sndFunType->getParamType(i))) + fstFunType->getParamTypeWithDirectionWrapper(i), + sndFunType->getParamTypeWithDirectionWrapper(i))) return false; } return TryUnifyTypes( diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index d8643c63f..1dcd0d737 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1465,7 +1465,7 @@ bool SemanticsVisitor::_coerce( } } - // matrix type with different layouts are convertible + // matrix types with different layouts are convertible if (auto fromMatrixType = as<MatrixExpressionType>(fromType)) { if (auto toMatrixType = as<MatrixExpressionType>(toType)) @@ -1491,6 +1491,14 @@ bool SemanticsVisitor::_coerce( } } + // We allow a value of a `struct` type to be coerced to a function + // type if the `struct` provides an appropriate method for calling + // instances of that type. + // + // TODO(tfoley): This can and should be opened up to work for any + // type (or at least any nominal type) that supports the required + // operation. + // if (auto toFuncType = as<FuncType>(toType)) { if (auto fromLambdaType = isDeclRefTypeOf<StructDecl>(fromType)) @@ -1541,6 +1549,16 @@ bool SemanticsVisitor::_coerce( // Is toType and fromType the same via some type equality witness? // If so there is no need to do any conversion. // + // Note that this is a somewhat messy case to have, since we *already* + // have a check for type equality above this point. For this code to + // execute we would need to have a case where the `To` and `From` types + // are considered distinct by `Type::equals` but `tryGetSubtypeWitness` + // is still able to produce a witness for the equality of the two types. + // + // TODO(tfoley): Try to set things up so that we can have an invariant + // that two types count as equal for `Type::equals` if and only if a + // type equality witness for those types can be dervied. + // if (isTypeEqualityWitness(fromIsToWitness)) { if (outToExpr) @@ -1562,29 +1580,54 @@ bool SemanticsVisitor::_coerce( return _failedCoercion(toType, outToExpr, fromExpr, sink); } - // We allow implicit conversion of a parameter group type like - // `ConstantBuffer<X>` or `ParameterBlock<X>` to its element - // type `X`. + // If the type that we are converting from is a parameter group type + // (something like `ConstantBuffer<X>` or `ParameterBlock<X>`) and we + // are converting to some type `Y`, then we want to allow for a multi-step + // conversion where we first implicitly dereference the parameter group + // to get an `X`, and then convert the resulting `X` to a `Y`. + // + // An important special case of the above is when `X == Y`, in which + // case we are just converting, e.g., a `ConstantBuffer<X>` to an `X`. + // + // TODO(tfoley): When this conditional detects a parameter group type + // it funnels the coercion logic into only considering conversions that + // involve an automatic dereference. We need to ensure that any other + // kinds of conversion that could apply to a parameter group are considered + // earlier in this function, or else they will never actually be considered. + // Notably, with this logic in place it is impossible for there to be any + // conversion operations from a parameter-group type defined in code + // (e.g., a constructor for a `DescriptorHandle`-like type that takes + // a `ConstantBufer<T>` parameter will never be considered as part of conversion + // logic, because we will first extract the `T` and then try to convert *that*). // if (auto fromParameterGroupType = as<ParameterGroupType>(fromType)) { auto fromElementType = fromParameterGroupType->getElementType(); - // If we convert, e.g., `ConstantBuffer<A> to `A`, we will allow - // subsequent conversion of `A` to `B` if such a conversion - // is possible. - // - ConversionCost subCost = kConversionCost_None; - DerefExpr* derefExpr = nullptr; if (outToExpr) { + // TODO(tfoley): The logic here effectively assumes that any + // parameter-group type is read-only, because we are not + // setting the `isLeftValue` flag of the `QualType` based + // on the type of the container. That is, a `StorageBuffer<X>` + // and a `ConstantBuffer<X>` would both derive the `QualType` + // of the dereferenced expression from `X` alone, and ignore + // that one of these should yield an l-value and the other + // shouldn't. + // + // In practice, we should have a centralized function that + // handles dereferenencing of any `Expr`, and computes the + // correct type for the result, so that the logic here can + // exactly mirror other cases of implicit dereference. + // derefExpr = m_astBuilder->create<DerefExpr>(); derefExpr->base = fromExpr; derefExpr->type = QualType(fromElementType); derefExpr->checked = true; } + ConversionCost subCost = kConversionCost_None; if (!_coerce(site, toType, outToExpr, fromElementType, derefExpr, sink, &subCost)) { return false; @@ -1595,40 +1638,74 @@ bool SemanticsVisitor::_coerce( return true; } - if (auto refType = as<RefTypeBase>(toType)) - { + // Because (for various bad reasons) we currently support an explicit + // `Ref<T>` type (used to define some of our core-module functions), + // we have to account for the case where an expression of type `T` + // is being coerced to a `Ref<T>`. + // + if (auto refType = as<ExplicitRefType>(toType)) + { + // TODO(tfoley): This logic is deeply and fundamentally incorrect. + // It presumes that if an expression of type `T` can coerce to + // type `U` then it can also coerce to a *reference* to `U`. + // That means that because we support, say, implicit coercion of + // an `int` to a `float`, this logic will support implicit coercion + // of an `int` l-value to a `Ref<float>`!!!! + // ConversionCost cost; if (!canCoerce(refType->getValueType(), fromType, fromExpr, &cost)) return false; - if (as<RefType>(toType) && !fromExpr->type.isLeftValue) + + // Depending on whether the result of the coercion would be an l-value + // or not, we may need to restrict the source to be an l-value. + // + // TODO(tfoley): Here we are again hijacking the `QualType` constructor + // to do the direct work. It's still not clear where this logic should + // live. In the longer run, I'm hopeful that we will get rid of + // the explicit `Ref` type entirely (since it was a design mistake to + // begin with), and thus not have to deal with the miserable mess that + // it pushes back on various parts of the compiler. + // + auto qualRefType = QualType(refType); + if (qualRefType.isLeftValue && !fromExpr->type.isLeftValue) + { + // The result type would be an l-value, but the source isn't, + // so there is no way to support the conversion. + // return false; + } + ConversionCost subCost = kConversionCost_GetRef; + if (outCost) + *outCost = subCost; - MakeRefExpr* refExpr = nullptr; if (outToExpr) { - refExpr = m_astBuilder->create<MakeRefExpr>(); + auto refExpr = m_astBuilder->create<MakeRefExpr>(); refExpr->base = fromExpr; - refExpr->type = QualType(refType); - refExpr->type.isLeftValue = false; + refExpr->type = qualRefType; refExpr->checked = true; *outToExpr = refExpr; } - if (outCost) - *outCost = subCost; + return true; } + // TODO(tfoley): I was told that explicit `Ref` types should not + // be seen by most of the compiler because they would be automatically + // eliminated via `maybeOpenRef()` before other code needs to deal + // with them... but that doesn't seem to be the case given how much + // code here in type coercion is having to account for the possibility + // of `Ref` types. - // Allow implicit dereferencing a reference type. - if (auto fromRefType = as<RefTypeBase>(fromType)) + // If we find ourselves in a situation where we need to coerce an + // expression of type `Ref<T>`, we will first unwrap the reference + // to get an expression of type `T` and then coerce *that*. + // + if (auto fromRefType = as<ExplicitRefType>(fromType)) { auto fromValueType = fromRefType->getValueType(); - // If we convert, e.g., `ConstantBuffer<A> to `A`, we will allow - // subsequent conversion of `A` to `B` if such a conversion - // is possible. - // ConversionCost subCost = kConversionCost_None; Expr* openRefExpr = nullptr; @@ -1642,6 +1719,26 @@ bool SemanticsVisitor::_coerce( return false; } + // + // TODO(tfoley): This logic treats the implicit dereferencing + // of a `Ref<T>` as an additional conversion cost, so that + // a function with an explicit `Ref<T>` parameter would end up + // being preferred over one with just a `T`. + // + // Making that distinction and introducing this cost seems to have + // very little benefit, and risks causing developer confusion, + // because for the most part references are invisible to the user + // (intentionally). + // + // We don't want to support explicit `Ref<T>` types in parameter + // positions anyway (people can use either a `ref` parameter or + // an explicit `Ptr<T>`), so the whole thing is moot. + // + // For that matter, we probably should just remove explicit + // `Ref<T>` types from the language, since they were never + // intended to be there in the first place. + // + if (outCost) *outCost = subCost + kConversionCost_ImplicitDereference; return true; @@ -2007,7 +2104,7 @@ bool SemanticsVisitor::tryCoerceLambdaToFuncType( for (auto param : invokeFunc->getParameters()) { auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param); - auto toParamType = toFuncType->getParamType(paramId); + auto toParamType = toFuncType->getParamTypeWithDirectionWrapper(paramId); if (!paramType->equals(toParamType)) { return false; @@ -2189,6 +2286,13 @@ Expr* SemanticsVisitor::coerce( // clobber the type on `fromExpr`, and an invariant here is that coercion // really shouldn't *change* the expression that is passed in, but should // introduce new AST nodes to coerce its value to a different type... + // + // TODO(tfoley): Based on the comment above it seems like my past self + // wrote this code, but looking at it now, I'm unsure why we want to return + // an expression with an error type when we have the `toType` that is + // expected *right there*. It would be good to investigate whether changing + // this to return an expression of the expected type would Just Work. + // return CreateImplicitCastExpr(m_astBuilder->getErrorType(), fromExpr); } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4362f0926..d6b50e999 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1649,7 +1649,7 @@ EnumDecl* isEnumType(Type* type) return nullptr; } -bool isNullableType(Type* type) +bool doesTypeHaveAnUnusedBitPatternThatCanBeUsedForOptionalRepresentation(Type* type) { if (as<PtrTypeBase>(type)) return true; @@ -1659,7 +1659,9 @@ bool isNullableType(Type* type) return true; if (as<OptionalType>(type)) return true; - if (as<RefTypeBase>(type)) + // TODO(tfoley): Somebody put the explicit `Ref<T>` types + // here as a list of a nullable type, and it is + if (as<ExplicitRefType>(type)) return true; if (as<NativeStringType>(type)) return true; @@ -10448,12 +10450,12 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl( decl->errorType.type = funcType->getErrorType(); for (Index i = 0; i < funcType->getParamCount(); i++) { - auto paramType = funcType->getParamType(i); - if (auto dirType = as<ParamDirectionType>(paramType)) - paramType = dirType->getValueType(); + auto paramInfo = funcType->getParamInfo(i); + auto paramType = paramInfo.type; + auto paramDir = paramInfo.direction; + auto param = m_astBuilder->create<ParamDecl>(); param->type.type = paramType; - auto paramDir = funcType->getParamDirection(i); switch (paramDir) { case ParameterDirection::kParameterDirection_InOut: @@ -13051,7 +13053,7 @@ void checkDerivativeAttributeImpl( Diagnostics::customDerivativeSignatureMismatchAtPosition, ii, qualTypeToString(argList[ii]->type), - funcType->getParamType(ii)->toString()); + funcType->getParamTypeWithDirectionWrapper(ii)->toString()); } } // The `imaginaryArguments` list does not include the `this` parameter. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 8994eb783..235b57ca6 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -223,12 +223,26 @@ Expr* SemanticsVisitor::maybeOpenRef(Expr* expr) { auto exprType = expr->type.type; - if (auto refType = as<RefTypeBase>(exprType)) + if (auto refType = as<ExplicitRefType>(exprType)) { auto openRef = m_astBuilder->create<OpenRefExpr>(); openRef->innerExpr = expr; - openRef->type.isLeftValue = (as<RefType>(exprType) != nullptr); + + // TODO(tfoley): The `QualType` constructor has its own + // logic to determine the value category (e.g., whether + // or not something is an l-value) when it is passed + // a `Ref` type. It is unclear whether both this code + // *and* that code are required, or if we can consolidate + // the two. + // + // Note that here we change the actual `Type*` stored in + // the `QualType` to be the underlying value type of the + // reference, whereas the `QualType` constructor does not + // perform such unwrapping. + // + openRef->type = QualType(refType); openRef->type.type = refType->getValueType(); + openRef->checked = true; openRef->loc = expr->loc; return openRef; @@ -538,10 +552,26 @@ Expr* SemanticsVisitor::constructDerefExpr(Expr* base, QualType elementType, Sou derefExpr->type = QualType(elementType); derefExpr->checked = true; - if (as<PtrType>(base->type) || as<RefType>(base->type)) + if (as<PtrType>(base->type)) { + // TODO(tfoley): It is not clear why this is being unconditionally + // set to `true` when the `Ptr` types in the core module has an + // `AccessQualifier` parameter that can be used to form a read-only pointer. + // derefExpr->type.isLeftValue = true; } + else if (as<ExplicitRefType>(base->type)) + { + // TODO(tfoley): The code here is exploiting the ability of the + // `QualType` constructor to compute the correct value category + // for a reference type, so that we don't have to repeat that logic + // here. That might not be the right place for that logic to live, + // however, and so the code here might need updating sooner or + // later. + // + bool baseIsLVal = QualType(base->type.type).isLeftValue; + derefExpr->type.isLeftValue = baseIsLVal; + } else if (isImmutableBufferType(base->type)) { derefExpr->type.isLeftValue = false; @@ -2925,7 +2955,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) Index paramCount = funcType->getParamCount(); for (Index pp = 0; pp < paramCount; ++pp) { - auto paramType = funcType->getParamType(pp); + auto paramType = funcType->getParamTypeWithDirectionWrapper(pp); Expr* argExpr = nullptr; ParamDecl* paramDecl = nullptr; if (pp < expr->arguments.getCount()) @@ -2936,7 +2966,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) } compareMemoryQualifierOfParamToArgument(paramDecl, argExpr); - if (as<OutTypeBase>(paramType) || as<RefType>(paramType)) + if (as<OutTypeBase>(paramType) || as<RefParamType>(paramType)) { // `out`, `inout`, and `ref` parameters currently require // an *exact* match on the type of the argument. @@ -3047,7 +3077,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) const DiagnosticInfo* diagnostic = nullptr; // Try and determine reason for failure - if (as<RefType>(paramType)) + if (as<RefParamType>(paramType)) { // Ref types are not allowed to use this mechanism because // it breaks atomics @@ -3537,21 +3567,66 @@ Expr* SemanticsExprVisitor::maybeRegisterLambdaCapture(Expr* exprIn) return resultMemberExpr; } -Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) +Type* SemanticsVisitor::_toDifferentialParamType(Type* primalParamType) { - // Check for type modifiers like 'out' and 'inout'. We need to differentiate the - // nested type. + // This function is invoked on parameter types that could + // still be wrapped to represent a parameter-passing mode + // like `ref`, `out`, etc. // - if (auto primalOutType = as<OutType>(primalType)) - { - return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType())); - } - else if (auto primalInOutType = as<InOutType>(primalType)) + // We need to intercept these cases here, and ensure that + // the wrapper is not exposed to other parts of the front-end + // code, because they only exist to encode the parameter-passing + // mode, and are not a proper part of the Slang type system + // (at least not at this time). + // + if (auto primalParamWrapperType = as<ParamDirectionType>(primalParamType)) { - return m_astBuilder->getInOutType( - _toDifferentialParamType(primalInOutType->getValueType())); + // Some parameter-passing modes do not naturally lend themselves + // to being differentiated - most notably, `ref` parameters. + // We will detect those cases here, and handle them as a parameter + // of a non-differentiable type would be handled. + // + // TODO(tfoley): With the introduction of `IDifferentiablePtrType`, + // it is possible that something like a `ref` parameter could also + // support autodiff, but it is not clear what a correct + // one-size-fits-all behavior should be in that case. + // + if (as<RefParamType>(primalParamType)) + return primalParamWrapperType; + + // Given a primal type that is a wrapper like `Out<T>`, we can + // extract the underlying primal value type `T`, and determine + // what the differential type value type corresponding to `T` + // should be. + // + auto primalValueType = primalParamWrapperType->getValueType(); + auto diffValueType = _toDifferentialParamType(primalValueType); + + // Once we have created the appropriate differential value type, + // we will form the differential parameter type by wrapping + // the differential value type in the same wrapper that had + // been used for the primal type. + // + if (as<OutType>(primalParamWrapperType)) + { + return m_astBuilder->getOutType(diffValueType); + } + else if (as<InOutType>(primalParamWrapperType)) + { + return m_astBuilder->getInOutType(diffValueType); + } + else if (as<ConstRefParamType>(primalParamWrapperType)) + { + return m_astBuilder->getConstRefParamType(diffValueType); + } + else + { + SLANG_UNEXPECTED("unhandled parameter-passing mode"); + UNREACHABLE_RETURN(diffValueType); + } } - return getDifferentialPairType(primalType); + + return getDifferentialPairType(primalParamType); } Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) @@ -3632,7 +3707,8 @@ Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) for (Index i = 0; i < originalType->getParamCount(); i++) { - if (auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) + if (auto jvpParamType = + _toDifferentialParamType(originalType->getParamTypeWithDirectionWrapper(i))) paramTypes.add(jvpParamType); } FuncType* jvpType = @@ -3658,7 +3734,9 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) for (Index i = 0; i < originalType->getParamCount(); i++) { - if (auto outType = as<OutType>(originalType->getParamType(i))) + auto originalParamType = originalType->getParamTypeWithDirectionWrapper(i); + + if (auto outType = as<OutType>(originalParamType)) { auto diffElementType = tryGetDifferentialType(m_astBuilder, outType->getValueType()); if (diffElementType) @@ -3670,7 +3748,7 @@ Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) continue; } } - else if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + else if (auto derivType = _toDifferentialParamType(originalParamType)) { if (as<DifferentialPairType>(derivType)) { diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index d82bf4427..dd5c816b1 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -3239,7 +3239,7 @@ bool isImmutableBufferType(Type* type); // Check if `type` is nullable. An `Optional<T>` will occupy the same space as `T`, if `T` // is nullable. -bool isNullableType(Type* type); +bool doesTypeHaveAnUnusedBitPatternThatCanBeUsedForOptionalRepresentation(Type* type); EnumDecl* isEnumType(Type* type); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index e0dbc7e08..2ad31c9d1 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -643,12 +643,35 @@ static QualType getParamQualType(ASTBuilder* astBuilder, DeclRef<ParamDecl> para static QualType getParamQualType(Type* paramType) { + // TODO(tfoley): This function probably shouldn't exist, and instead + // the accessors for the parameters of a `FuncType` should + // directly return a `QualType` for each parameter rather than + // a plain `Type` that potentially includes a wrapping + // `ParamDirectionType`. + // + // In addition, the determination of what value category a reference + // to a parameter should be (and thus what the `QualType` sould be) + // should be driven by computing the `ParameterDirection` first, + // and then using the direction to determine the value category + // (so as to isolate the code that needs to care about the wrapper + // types to just the computation of the dirction). + // + // Note the large amount of duplication between this function and + // the other `getParamQualType()` above. + // + bool isLVal = false; + Type* valueType = paramType; if (auto paramDirType = as<ParamDirectionType>(paramType)) { - if (as<OutTypeBase>(paramDirType) || as<RefType>(paramDirType)) - return QualType(paramDirType->getValueType(), true); + valueType = paramDirType->getValueType(); + if (as<InOutParamType>(paramDirType)) + isLVal = true; + if (as<OutParamType>(paramDirType)) + isLVal = true; + if (as<RefParamType>(paramDirType)) + isLVal = true; } - return paramType; + return QualType(valueType, isLVal); } bool SemanticsVisitor::TryCheckOverloadCandidateTypes( @@ -673,7 +696,7 @@ bool SemanticsVisitor::TryCheckOverloadCandidateTypes( Count paramCount = funcType->getParamCount(); for (Index i = 0; i < paramCount; ++i) { - auto paramType = getParamQualType(funcType->getParamType(i)); + auto paramType = getParamQualType(funcType->getParamTypeWithDirectionWrapper(i)); paramTypes.add(paramType); } } @@ -2666,7 +2689,8 @@ void SemanticsVisitor::AddHigherOrderOverloadCandidates( List<QualType> paramTypes; for (Index ii = 0; ii < diffFuncType->getParamCount(); ii++) - paramTypes.add(getParamQualType(diffFuncType->getParamType(ii))); + paramTypes.add( + getParamQualType(diffFuncType->getParamTypeWithDirectionWrapper(ii))); // Try to infer generic arguments, based on the updated context. OverloadResolveContext subContext = context; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 7c9111629..29134131c 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -778,13 +778,13 @@ Type* getParamTypeWithDirectionWrapper(ASTBuilder* astBuilder, DeclRef<VarDeclBa case kParameterDirection_In: return result; case kParameterDirection_ConstRef: - return astBuilder->getConstRefType(result); + return astBuilder->getConstRefParamType(result); case kParameterDirection_Out: return astBuilder->getOutType(result); case kParameterDirection_InOut: return astBuilder->getInOutType(result); case kParameterDirection_Ref: - return astBuilder->getRefType(result); + return astBuilder->getRefParamType(result); default: return result; } diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index 82f9596e6..e3a93bd86 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -27,7 +27,7 @@ Type* getPointedToTypeIfCanImplicitDeref(Type* type) { return ptrType->getValueType(); } - else if (auto refType = as<RefType>(type)) + else if (auto refType = as<ExplicitRefType>(type)) { return refType->getValueType(); } diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.cpp b/source/slang/slang-ir-fix-entrypoint-callsite.cpp index 7390f3a7f..a0ab07928 100644 --- a/source/slang/slang-ir-fix-entrypoint-callsite.cpp +++ b/source/slang/slang-ir-fix-entrypoint-callsite.cpp @@ -63,6 +63,11 @@ void fixEntryPointCallsites(IRFunc* entryPoint) // and the caller is passing a value, we need to wrap the value in a temporary var // and pass the temporary var. // + // TODO(tfoley): Wait, what? The situation this code is trying to fix should + // never be allowed to occur in the first place. This code shouldn't be + // trying to defend against the bad input; instead we should be *fixing* + // the source of the problem. + // auto funcType = as<IRFuncType>(callee->getDataType()); SLANG_ASSERT(funcType); IRBuilder builder(call); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 792848ce6..3cf5d7803 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2042,7 +2042,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower List<IRType*> paramTypes; for (Index pp = 0; pp < paramCount; ++pp) { - paramTypes.add(lowerType(context, type->getParamType(pp))); + paramTypes.add(lowerType(context, type->getParamTypeWithDirectionWrapper(pp))); } if (type->getErrorType()->equals(context->astBuilder->getBottomType())) { @@ -2820,12 +2820,14 @@ void addArg( // from the arg. paramType = lowerType(context, argType); } +#if 0 if (auto refType = as<IRConstRefType>(paramType)) { paramType = refType->getValueType(); argVal = LoweredValInfo::simple( context->irBuilder->emitLoad(getSimpleVal(context, argPtr))); } +#endif LoweredValInfo tempVar = createVar(context, paramType); @@ -3814,13 +3816,13 @@ struct ExprLoweringContext for (Index i = 0; i < argCount; ++i) { - IRType* paramType = lowerType(context, funcType->getParamType(i)); - ParameterDirection paramDirection = funcType->getParamDirection(i); + auto paramInfo = funcType->getParamInfo(i); + IRType* paramType = lowerType(context, paramInfo.type); addDirectCallArgs( expr, i, paramType, - paramDirection, + paramInfo.direction, DeclRef<ParamDecl>(), ioArgs, ioFixups); @@ -4736,16 +4738,55 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> if (loweredBase.flavor != LoweredValInfo::Flavor::Ptr) { - SLANG_ASSERT(as<ConstRefType>(expr->type)); - // If the base isn't a pointer, then we are trying to form - // a const ref to a temporary value. - // To do so we must copy it into a variable. + // If the base expression is not one that (trivially) + // lower to a pointer, then we have a bit of a problem, + // because the semantics of forming a reference are + // that we should refer to the memory location of + // the operand itself. + // + // For now, we are hacking this case by supporting + // formation of a *read-only* reference when the base + // expression is an r-value, by first copying the base + // expression into a temporary. + // + // Note that this approach is semantically incorrect, + // and a fix should be made further up the stack to + // rule out whatever is happening here. + // + // TODO(tfoley): Investigate why this case is arising + // at all, and/or eliminate the explicit `Ref` type + // entirely, so we don't have to deal with it. + + + // We start by asserting that the reference type we + // are being asked to form is read-only. + // + SLANG_ASSERT(as<ExplicitRefType>(expr->type) && !QualType(expr->type).isLeftValue); + + // Now we perpetrate our hackery, by forming a simple value + // for the operand in an SSA register and copying it into + // a temporary. + // + // TODO(tfoley): This logic might be better expressed by + // forming a `LoweredValInfo` for the temporary and then + // using the `assign()` operation to write the base into it, + // since that operation might produce simpler code than + // we get by using `getSimpleVal` here. + // auto baseVal = getSimpleVal(context, loweredBase); auto tempVar = context->irBuilder->emitVar(baseVal->getFullType()); context->irBuilder->emitStore(tempVar, baseVal); loweredBase.val = tempVar; } + // Note that the `flavor` of the lowered value that we return + // is always `Simple`, because at the level of the IR a value + // of type `Ref` is just a pointer. + // + // In the case where the hack above was used to introduce a + // temporary, the pointer value is the address of the temporary + // variable itself. + // loweredBase.flavor = LoweredValInfo::Flavor::Simple; return loweredBase; } diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 311725e08..d96b5591b 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -253,7 +253,7 @@ void emitType(ManglingContext* context, Type* type) auto n = funcType->getParamCount(); emit(context, n); for (Index i = 0; i < n; ++i) - emitType(context, funcType->getParamType(i)); + emitType(context, funcType->getParamTypeWithDirectionWrapper(i)); emitType(context, funcType->getResultType()); emitType(context, funcType->getErrorType()); } diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index d9eb884f0..b991a0caf 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -925,13 +925,18 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR { paramType = astBuilder->getErrorType(); } + + // TODO(tfoley): This code should first compute the appropriate + // parameter-passing mode ("direction") for the `paramDecl` and + // then use that mode to decide which wrapper type to use. + // if (paramDecl->findModifier<RefModifier>()) { - paramType = astBuilder->getRefType(paramType); + paramType = astBuilder->getRefParamType(paramType); } else if (paramDecl->findModifier<ConstRefModifier>()) { - paramType = astBuilder->getConstRefType(paramType); + paramType = astBuilder->getConstRefParamType(paramType); } else if (paramDecl->findModifier<OutModifier>()) { diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 519b3ab06..0de6348bf 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -5088,9 +5088,26 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type } else if (auto optionalType = as<OptionalType>(type)) { - // OptionalType should be laid out the same way as Tuple<T, bool>. - if (isNullableType(optionalType->getValueType())) + // Sometimes a type `T` has an unused bit pattern that + // can be used to represent the null/absent optional value, + // and for such types the size of an `Optional<T>` can be + // the same as a `T`, by making use of that unused pattern. + // + if (doesTypeHaveAnUnusedBitPatternThatCanBeUsedForOptionalRepresentation( + optionalType->getValueType())) return _createTypeLayout(context, optionalType->getValueType()); + + // For all other types, an `Optional<T>` is laid out more-or-less + // as tuple of a `T` and a `bool`. + // + // TODO(tfoley): This code implements the `(T,bool)` ordering, + // which provides more easy opportunities to generate compact + // layouts by using "tail padding" than the `(bool, T)` ordering. + // However the "natural layout" implementation does not match + // what is being done here (it uses the `(bool, T)` ordering). + // The discrepancy should probably be fixed, but doing so would + // technically be a breaking change. + // Array<Type*, 2> types = makeArray(optionalType->getValueType(), context.astBuilder->getBoolType()); auto tupleType = context.astBuilder->getTupleType(types.getView()); diff --git a/tests/bugs/array-size-groupshared.slang b/tests/bugs/array-size-groupshared.slang index 1acda0292..88adda28e 100644 --- a/tests/bugs/array-size-groupshared.slang +++ b/tests/bugs/array-size-groupshared.slang @@ -16,6 +16,18 @@ struct GenType<T : __BuiltinIntegerType, A: IA, let N : int, let M : int> { static const int HalfN = N > 1? N / A.M : 1; static const int P = M + N; + + // TODO(tfoley): What this test is testing seems to be outside + // the scope of what we ever intend to support in user code. + // Returning an `Ref<T>` is supposed to be a core-module-only + // thing, and even then is something that we would like to do less + // of over time. + // + // The only purpose of this test *seems* to be ensuring that this + // particular function (and the `groupshared` declaration inside + // it) "works," but the function itself is not something that we + // intend to be supported in Slang. + // [ForceInline] Ref<uint> weights(int index) { diff --git a/tests/bugs/generic-groupshared.slang b/tests/bugs/generic-groupshared.slang index 9208f795a..c52f9be03 100644 --- a/tests/bugs/generic-groupshared.slang +++ b/tests/bugs/generic-groupshared.slang @@ -4,6 +4,17 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer RWStructuredBuffer<int> outputBuffer; +// TODO(tfoley): What this test is testing seems to be outside +// the scope of what we ever intend to support in user code. +// Returning an `Ref<T>` is supposed to be a core-module-only +// thing, and even then is something that we would like to do less +// of over time. +// +// The only purpose of this test *seems* to be ensuring that this +// particular function (and the `groupshared` declaration inside +// it) "works," but the function itself is not something that we +// intend to be supported in Slang. +// [ForceInline] Ref<uint> table<let n: int>(int index) { diff --git a/tests/diagnostics/autodiff-custom-diff-inout.slang.expected b/tests/diagnostics/autodiff-custom-diff-inout.slang.expected index 103c94c9e..2469d7257 100644 --- a/tests/diagnostics/autodiff-custom-diff-inout.slang.expected +++ b/tests/diagnostics/autodiff-custom-diff-inout.slang.expected @@ -3,7 +3,7 @@ standard error = { tests/diagnostics/autodiff-custom-diff-inout.slang(3): error 30019: expected an expression of type 'float', got 'DifferentialPair<float>' [BackwardDerivative(__d_f)] ^~~~~~~~~~~~~~~~~~ -tests/diagnostics/autodiff-custom-diff-inout.slang(3): error 31149: invalid custom derivative. parameter type mismatch at position 0. expected 'InOut<DifferentialPair<float>>', got 'float' +tests/diagnostics/autodiff-custom-diff-inout.slang(3): error 31149: invalid custom derivative. parameter type mismatch at position 0. expected 'inout DifferentialPair<float>', got 'float' [BackwardDerivative(__d_f)] ^~~~~~~~~~~~~~~~~~ } diff --git a/tests/spirv/get-vertex-attribute.slang b/tests/spirv/get-vertex-attribute.slang index 655b7ad03..84807682a 100644 --- a/tests/spirv/get-vertex-attribute.slang +++ b/tests/spirv/get-vertex-attribute.slang @@ -1,5 +1,5 @@ //TEST:SIMPLE(filecheck=CHECK): -target spirv -//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-via-glsl +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-via-glsl -dump-intermediates // CHECK: OpDecorate %vout_vertexID{{.*}} PerVertexKHR |
