diff options
| author | Julius Ikkala <julius.ikkala@gmail.com> | 2025-06-28 05:39:24 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-28 02:39:24 +0000 |
| commit | 7349dc5cff49cf22c82eb912813e47f30cd7a757 (patch) | |
| tree | 4d7b3e14f119e7bb48623e52c890b461fd3d9701 | |
| parent | a13dda4f214274a10d39f37c79622fc3e62da310 (diff) | |
Minimal optional constraints (#7422)
* Parse optional witness syntax
* Allow failing optional constraint
* Make `is` work with optional constraint
* Allow using optional constraint in checked if statements
* Fix tests
* Make it work with structs
* Fix MSVC build error
* Disallow using `as` with optional constraints
* Update test to match split is/as errors
* Add tests
* Fix uninitialized variables in tests
* Add tests of incorrect uses & fix related bugs
* Mention optional constraints in docs
* format code
* Fix type unification with NoneWitness
* Fix formatting
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Co-authored-by: Nathan V. Morrical <natemorrical@gmail.com>
24 files changed, 556 insertions, 65 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md index bb19fd776..7c957f685 100644 --- a/docs/user-guide/06-interfaces-generics.md +++ b/docs/user-guide/06-interfaces-generics.md @@ -144,6 +144,28 @@ struct MyType<T, U> } ``` +Optional conformances can be expressed compactly using the `where optional` syntax: +```csharp +// Together, these two overloads... +int myGenericMethod<T>(T arg) +{ +} + +int myGenericMethod<T>(T arg) where T: IFoo +{ + arg.myMethod(1.0); +} + +// ... are equivalent to: +int myGenericMethod<T>(T arg) where optional T: IFoo +{ + if (T is IFoo) + { + arg.myMethod(1.0); // OK in a block that checks for T: IFoo conformance. + } +} +``` + Supported Constructs in Interface Definitions ----------------------------------------------------- diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index ca023a19b..88dea0b7e 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -219,6 +219,13 @@ class IgnoreForLookupModifier : public Modifier FIDDLE(...) }; +/// A modifier that indicates an `TypeConstraintDecl` is optional. +FIDDLE() +class OptionalConstraintModifier : public Modifier +{ + FIDDLE(...) +}; + // A modifier that marks something as an operation that // has a one-to-one translation to the IR, and thus // has no direct definition in the high-level language. diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index e4002c237..572d05d9f 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -123,6 +123,7 @@ FIDDLE() namespace Slang kConversionCost_ValToOptional = 150, kConversionCost_NullPtrToPtr = 150, kConversionCost_PtrToVoidPtr = 150, + kConversionCost_FailedOptionalConstraint = 150, // Conversions that are lossless, but change "kind" kConversionCost_UnsignedToSignedPromotion = 200, diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 40daae3a6..0e942fe39 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -896,6 +896,18 @@ Val* TypeCoercionWitness::_resolveImplOverride() return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NoneWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +void NoneWitness::_toTextOverride(StringBuilder& out) +{ + out.append("none"); +} + +Val* NoneWitness::_resolveImplOverride() +{ + return this; +} + // UNormModifierVal void UNormModifierVal::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index a8b969a94..f371d76fe 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -697,6 +697,13 @@ class DeclaredSubtypeWitness : public SubtypeWitness return false; } + bool isOptional() + { + if (auto declRef = getDeclRef().as<GenericTypeConstraintDecl>()) + return declRef.getDecl()->hasModifier<OptionalConstraintModifier>(); + return false; + } + // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); @@ -832,6 +839,16 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness ConversionCost _getOverloadResolutionCostOverride(); }; +/// A witness for the "none" value of optional constraints. +FIDDLE() +class NoneWitness : public Witness +{ + FIDDLE(...) + + void _toTextOverride(StringBuilder& out); + Val* _resolveImplOverride(); +}; + /// A value that represents a modifier attached to some other value FIDDLE() class ModifierVal : public Val diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 71b87f447..ba1b8ea55 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -122,6 +122,8 @@ SubtypeWitness* SemanticsVisitor::checkAndConstructSubtypeWitness( ensureDecl(superDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup); } + SubtypeWitness* failureWitness = nullptr; + // In the common case, we can use the pre-computed inheritance information for `subType` // to enumerate all the types it transitively inherits from. // @@ -148,9 +150,19 @@ SubtypeWitness* SemanticsVisitor::checkAndConstructSubtypeWitness( // If the `superType` appears in the flattened inheritance list // for the `subType`, then we know that the subtype relationship - // holds. Conveniently, the `facet` stores a pre-computed witness - // for the subtype relationship, which we can return here. - // + // holds. + + // If the witness is optional, we should only return it if no certain + // witness was found. + auto declWitness = as<DeclaredSubtypeWitness>(facet->subtypeWitness); + if (declWitness && declWitness->isOptional()) + { + failureWitness = facet->subtypeWitness; + continue; + } + + // Conveniently, the `facet` stores a pre-computed witness for the + // subtype relationship, which we can use here. return facet->subtypeWitness; } // @@ -271,7 +283,7 @@ SubtypeWitness* SemanticsVisitor::checkAndConstructSubtypeWitness( return m_astBuilder->getEachSubtypeWitness(subType, superType, elementWitness); } // default is failure - return nullptr; + return failureWitness; } bool SemanticsVisitor::isValidGenericConstraintType(Type* type) diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 3020554c8..7c55c440c 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -360,9 +360,12 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem( for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, genericDeclRef)) { + ValUnificationContext unificationContext; + unificationContext.optionalConstraint = + constraintDeclRef.getDecl()->hasModifier<OptionalConstraintModifier>(); if (!TryUnifyTypes( *system, - ValUnificationContext(), + unificationContext, getSub(m_astBuilder, constraintDeclRef), getSup(m_astBuilder, constraintDeclRef))) return DeclRef<Decl>(); @@ -487,8 +490,11 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem( auto joinType = TryJoinTypes(system, type, cType); if (!joinType) { - // failure! - return DeclRef<Decl>(); + if (c.isOptional) + joinType = type; + else + // failure! + return DeclRef<Decl>(); } type = QualType(joinType, type.isLeftValue || cType.isLeftValue); } @@ -696,12 +702,22 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem( subTypeWitness = nullptr; } - if (subTypeWitness) + bool witnessIsOptional = isWitnessUncheckedOptional(subTypeWitness); + bool constraintIsOptional = constraintDecl->hasModifier<OptionalConstraintModifier>(); + + if (subTypeWitness && (!witnessIsOptional || constraintIsOptional)) { // We found a witness, so it will become an (implicit) argument. args.add(subTypeWitness); outBaseCost += subTypeWitness->getOverloadResolutionCost(); } + else if (!subTypeWitness && constraintIsOptional) + { + // Optional witness failed to resolve; not an error. + auto noneWitness = m_astBuilder->getOrCreate<NoneWitness>(); + args.add(noneWitness); + outBaseCost += kConversionCost_FailedOptionalConstraint; + } else { // No witness was found, so the inference will now fail. @@ -851,13 +867,20 @@ bool SemanticsVisitor::TryUnifyVals( // Two subtype witnesses can be unified if they exist (non-null) and // prove that some pair of types are subtypes of types that can be unified. // - if (auto fstWit = as<SubtypeWitness>(fst)) - { - if (auto sndWit = as<SubtypeWitness>(snd)) - { - return TryUnifyTypes(constraints, unifyCtx, fstWit->getSup(), sndWit->getSup()); - } - } + const auto fstSubtypeWitness = as<SubtypeWitness>(fst); + const auto sndSubtypeWitness = as<SubtypeWitness>(snd); + const auto fstNoneWitness = as<NoneWitness>(fst); + const auto sndNoneWitness = as<NoneWitness>(snd); + if (fstSubtypeWitness && sndSubtypeWitness) + return TryUnifyTypes( + constraints, + unifyCtx, + fstSubtypeWitness->getSup(), + sndSubtypeWitness->getSup()); + else if (fstNoneWitness && sndNoneWitness) + return true; + else if ((fstNoneWitness && sndSubtypeWitness) || (fstSubtypeWitness && sndNoneWitness)) + return false; SLANG_UNIMPLEMENTED_X("value unification case"); @@ -946,6 +969,7 @@ bool SemanticsVisitor::TryUnifyTypeParam( constraint.indexInPack = unificationContext.indexInTypePack; constraint.val = type; constraint.isUsedAsLValue = type.isLeftValue; + constraint.isOptional = unificationContext.optionalConstraint; constraints.constraints.add(constraint); return true; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 4c6bf98d2..306687bd8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1023,6 +1023,100 @@ LookupResult SemanticsVisitor::filterLookupResultByVisibilityAndDiagnose( return result; } +bool SemanticsVisitor::isWitnessUncheckedOptional(SubtypeWitness* witness) +{ + auto declaredWitness = as<DeclaredSubtypeWitness>(witness); + if (!declaredWitness) + return false; + + auto decl = declaredWitness->getDeclRef().getDecl(); + if (!decl || !decl->hasModifier<OptionalConstraintModifier>()) + return false; + + // Okay, we've found an optional subtype witness. This result needs + // to be removed if we're not inside a block that directly checks + // if (sub is sup) + auto sub = witness->getSub(); + auto sup = witness->getSup(); + + for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) + { + auto outerStmt = outerStmtInfo->stmt; + auto ifStmt = as<IfStmt>(outerStmt); + + if (!ifStmt) + continue; + + IsTypeExpr* isType = as<IsTypeExpr>(ifStmt->predicate); + if (!isType) + continue; + VarExpr* var = as<VarExpr>(isType->value); + if (!var) + continue; + TypeType* typeType = as<TypeType>(var->type); + + // var->type works for `variable is Interface`, while + // typeType->getType() is for `T is Interface`. + auto type = typeType ? typeType->getType() : var->type.type; + if (type == sub && isType->typeExpr.type == sup) + { + return false; + } + } + + // If we got this far, it's both an optional witness and there's no + // statement checking its validity. + return true; +} + +LookupResult SemanticsVisitor::filterLookupResultByCheckedOptional(const LookupResult& lookupResult) +{ + LookupResult filteredResult; + for (auto item : lookupResult) + { + bool optionalConstraintsChecked = true; + + for (auto bb = item.breadcrumbs; bb; bb = bb->next) + { + auto witness = as<SubtypeWitness>(bb->val); + if (!witness) + continue; + + if (isWitnessUncheckedOptional(witness)) + { + optionalConstraintsChecked = false; + break; + } + } + + if (optionalConstraintsChecked) + AddToLookupResult(filteredResult, item); + } + return filteredResult; +} + +LookupResult SemanticsVisitor::filterLookupResultByCheckedOptionalAndDiagnose( + const LookupResult& lookupResult, + SourceLoc loc, + bool& outDiagnosed) +{ + auto result = filterLookupResultByCheckedOptional(lookupResult); + if (lookupResult.isValid() && !result.isValid()) + { + getSink()->diagnose( + loc, + Diagnostics::requiredConstraintIsNotChecked, + lookupResult.item.declRef); + outDiagnosed = true; + + if (getShared()->isInLanguageServer()) + { + return lookupResult; + } + } + return result; +} + LookupResult SemanticsVisitor::resolveOverloadedLookup(LookupResult const& inResult) { // If the result isn't actually overloaded, it is fine as-is @@ -4068,19 +4162,16 @@ Expr* SemanticsExprVisitor::visitIsTypeExpr(IsTypeExpr* expr) expr->type = m_astBuilder->getBoolType(); expr->value = originalVal; - // Check if the right-hand side type is an interface type - if (isInterfaceType(expr->typeExpr.type)) - { - getSink()->diagnose(expr, Diagnostics::isAsOperatorCannotUseInterfaceAsRHS); - return expr; - } - auto valueType = expr->value->type.type; if (auto typeType = as<TypeType>(valueType)) valueType = typeType->getType(); // If value is a subtype of `type`, then this expr is always true. - if (isSubtype(valueType, expr->typeExpr.type, IsSubTypeOptions::None)) + auto witness = isSubtype(valueType, expr->typeExpr.type, IsSubTypeOptions::None); + auto declWitness = as<DeclaredSubtypeWitness>(witness); + bool optionalWitness = declWitness && declWitness->isOptional(); + + if (witness && !optionalWitness) { // Instead of returning a BoolLiteralExpr, we use a field to indicate this scenario, // so that the language server can still see the original syntax tree. @@ -4091,15 +4182,24 @@ Expr* SemanticsExprVisitor::visitIsTypeExpr(IsTypeExpr* expr) return expr; } + // Check if the right-hand side type is an interface type. For 'is' + // statements, that's only allowed if it's related to an optional + // constraint. + if (isInterfaceType(expr->typeExpr.type) && !optionalWitness) + { + getSink()->diagnose(expr, Diagnostics::isOperatorCannotUseInterfaceAsRHS); + return expr; + } + // Otherwise, if the target type is a subtype of value->type, we need to grab the // subtype witness for runtime checks. expr->value = maybeOpenExistential(originalVal); - expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, valueType); + expr->witnessArg = witness ? witness : tryGetSubtypeWitness(expr->typeExpr.type, valueType); if (expr->witnessArg) { // For now we can only support the scenario where `expr->value` is an interface type. - if (!isInterfaceType(originalVal->type)) + if (!optionalWitness && !isInterfaceType(originalVal->type)) { getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); } @@ -4117,7 +4217,7 @@ Expr* SemanticsExprVisitor::visitAsTypeExpr(AsTypeExpr* expr) // Check if the right-hand side type is an interface type if (isInterfaceType(typeExpr.type)) { - getSink()->diagnose(expr, Diagnostics::isAsOperatorCannotUseInterfaceAsRHS); + getSink()->diagnose(expr, Diagnostics::asOperatorCannotUseInterfaceAsRHS); expr->type = m_astBuilder->getErrorType(); return expr; } @@ -5162,6 +5262,8 @@ Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* bas lookUpMember(m_astBuilder, this, expr->name, baseType, m_outerScope); bool diagnosed = false; lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); + lookupResult = + filterLookupResultByCheckedOptionalAndDiagnose(lookupResult, expr->loc, diagnosed); if (!lookupResult.isValid()) { return lookupMemberResultFailure(expr, baseType, diagnosed); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index d4d431914..e73f65d75 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1372,6 +1372,13 @@ public: SourceLoc loc, bool& outDiagnosed); + bool isWitnessUncheckedOptional(SubtypeWitness* witness); + LookupResult filterLookupResultByCheckedOptional(const LookupResult& lookupResult); + LookupResult filterLookupResultByCheckedOptionalAndDiagnose( + const LookupResult& lookupResult, + SourceLoc loc, + bool& outDiagnosed); + Val* resolveVal(Val* val) { if (!val) @@ -2650,6 +2657,7 @@ public: struct ValUnificationContext { Index indexInTypePack = 0; + bool optionalConstraint = false; }; // Try to find a unification for two values diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 7a89f2e23..6c0a7f184 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -994,10 +994,18 @@ bool SemanticsVisitor::TryCheckOverloadCandidateConstraints( auto sup = getSup(m_astBuilder, constraintDeclRef); auto subTypeWitness = tryGetSubtypeWitness(sub, sup); - if (subTypeWitness) + + bool witnessIsOptional = isWitnessUncheckedOptional(subTypeWitness); + bool constraintIsOptional = constraintDecl->hasModifier<OptionalConstraintModifier>(); + + if (subTypeWitness && (!witnessIsOptional || constraintIsOptional)) { newArgs.add(subTypeWitness); } + else if (!subTypeWitness && constraintIsOptional) + { + newArgs.add(m_astBuilder->getOrCreate<NoneWitness>()); + } else { if (context.mode != OverloadResolveContext::Mode::JustTrying) diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index faf25b716..bf8ddc94e 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -522,9 +522,10 @@ void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt) void SemanticsStmtVisitor::visitIfStmt(IfStmt* stmt) { + WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); - checkStmt(stmt->positiveStatement); - checkStmt(stmt->negativeStatement); + subContext.checkStmt(stmt->positiveStatement); + subContext.checkStmt(stmt->negativeStatement); } void SemanticsStmtVisitor::visitUnparsedStmt(UnparsedStmt*) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d54e1a3e0..384a81f9b 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -845,9 +845,18 @@ DIAGNOSTIC( DIAGNOSTIC( 30301, Error, - isAsOperatorCannotUseInterfaceAsRHS, - "'is' and 'as' operators do not support interface types as the right-hand side. Use a concrete " - "type instead.") + isOperatorCannotUseInterfaceAsRHS, + "cannot use 'is' operator with an interface type as the right-hand " + "side without a corresponding optional constraint. Use a concrete type " + "instead, or add an optional constraint for the interface type.") + +DIAGNOSTIC( + 30302, + Error, + asOperatorCannotUseInterfaceAsRHS, + "cannot use 'as' operator with an interface type as the right-hand " + "side. Use a concrete type instead. If you want to use an optional " + "constraint, use an 'if (T is IInterface)' block instead.") DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'") @@ -1610,6 +1619,12 @@ DIAGNOSTIC( Error, invalidConstraintSubType, "type '$0' is not a valid left hand side of a type constraint.") +DIAGNOSTIC( + 30403, + Error, + requiredConstraintIsNotChecked, + "the constraint providing '$0' is optional and must be checked with an 'is' statement before " + "usage.") // 305xx: initializer lists DIAGNOSTIC(30500, Error, tooManyInitializers, "too many initializers (expected $0, got $1)") diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index 80b258407..41a78d00c 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -295,9 +295,17 @@ struct GenericCallLoweringContext return; } - auto interfaceType = cast<IRInterfaceType>( + auto interfaceType = as<IRInterfaceType>( cast<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType()) ->getConformanceType()); + + if (!interfaceType) + { + // NoneWitness -> remove call. + callInst->removeAndDeallocate(); + return; + } + if (isBuiltin(interfaceType)) return; diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index f2b1d1a6a..c003d6125 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -296,8 +296,11 @@ struct GenericFunctionLoweringContext // and emission of wrapper functions. void lowerWitnessTable(IRWitnessTable* witnessTable) { - auto interfaceType = - maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTable->getConformanceType())); + IRInterfaceType* conformanceType = as<IRInterfaceType>(witnessTable->getConformanceType()); + if (!conformanceType) + return; + + auto interfaceType = maybeLowerInterfaceType(conformanceType); IRBuilder builderStorage(sharedContext->module); auto builder = &builderStorage; builder->setInsertBefore(witnessTable); @@ -353,8 +356,17 @@ struct GenericFunctionLoweringContext return; if (witnessTableType->getConformanceType()->findDecoration<IRComInterfaceDecoration>()) return; - auto interfaceType = - maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType())); + + IRInterfaceType* conformanceType = + as<IRInterfaceType>(witnessTableType->getConformanceType()); + + // NoneWitness generates conformance types which aren't interfaces. In + // that case, the method can just be skipped entirely, since there's no + // real witness for it and it should be in unreachable code at this + // point. + if (!conformanceType) + return; + auto interfaceType = maybeLowerInterfaceType(conformanceType); interfaceRequirementVal = sharedContext->findInterfaceRequirementVal( interfaceType, lookupInst->getRequirementKey()); diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index e4d1af93a..0c8b248b0 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -232,23 +232,31 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex { auto interfaceType = cast<IRWitnessTableType>(inst->getDataType())->getConformanceType(); - auto interfaceLinkage = interfaceType->findDecoration<IRLinkageDecoration>(); - SLANG_ASSERT( - interfaceLinkage && "An interface type does not have a linkage," - "but a witness table associated with it has one."); - auto interfaceName = interfaceLinkage->getMangledName(); - auto idAllocator = - linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( - interfaceName); - if (!idAllocator) + if (as<IRInterfaceType>(interfaceType)) { - linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0; - idAllocator = + auto interfaceLinkage = interfaceType->findDecoration<IRLinkageDecoration>(); + SLANG_ASSERT( + interfaceLinkage && "An interface type does not have a linkage," + "but a witness table associated with it has one."); + auto interfaceName = interfaceLinkage->getMangledName(); + auto idAllocator = linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( interfaceName); + if (!idAllocator) + { + linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0; + idAllocator = + linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( + interfaceName); + } + seqID = *idAllocator; + ++(*idAllocator); + } + else + { + // NoneWitness, has special ID of -1. + seqID = uint32_t(-1); } - seqID = *idAllocator; - ++(*idAllocator); linkage->mapMangledNameToRTTIObjectIndex[witnessTableMangledName] = seqID; } diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index fabfd1611..cfda6225e 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -171,7 +171,9 @@ struct GenerateWitnessTableWrapperContext void lowerWitnessTable(IRWitnessTable* witnessTable) { - auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType()); + auto interfaceType = as<IRInterfaceType>(witnessTable->getConformanceType()); + if (!interfaceType) + return; if (isBuiltin(interfaceType)) return; if (isComInterfaceType(interfaceType)) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index df3507670..8e13a3dc5 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1990,6 +1990,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower builder->emitGetTupleElement(elementType, conjunctionWitness, indexInConjunction)); } + LoweredValInfo visitNoneWitness(NoneWitness*) + { + auto builder = getBuilder(); + auto voidType = builder->getVoidType(); + return LoweredValInfo::simple(builder->createWitnessTable(voidType, voidType)); + } LoweredValInfo visitConstantIntVal(ConstantIntVal* val) { @@ -5472,23 +5478,37 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitIsTypeExpr(IsTypeExpr* expr) { + auto builder = getBuilder(); if (expr->constantVal) { - return LoweredValInfo::simple(getBuilder()->getBoolValue(expr->constantVal->value)); + return LoweredValInfo::simple(builder->getBoolValue(expr->constantVal->value)); } - // If expr is a witness, then this is a run-time type check from for an existential type. if (expr->witnessArg) { - auto value = lowerLValueExpr(context, expr->value); auto type = lowerType(context, expr->typeExpr.type); auto witness = lowerSimpleVal(context, expr->witnessArg); - auto existentialInfo = value.getExtractedExistentialValInfo(); - auto irVal = getBuilder()->emitIsType( - existentialInfo->extractedVal, - existentialInfo->witnessTable, - type, - witness); - return LoweredValInfo::simple(irVal); + auto declWitness = as<DeclaredSubtypeWitness>(expr->witnessArg); + + if (declWitness && declWitness->isOptional()) + { + // Optional constraint check. NoneWitness lowers to a specific + // ID, so that we can check for that here. + auto witnessID = builder->emitGetSequentialIDInst(witness); + auto noneWitnessID = builder->getIntValue(builder->getUIntType(), -1); + auto irVal = builder->emitNeq(witnessID, noneWitnessID); + return LoweredValInfo::simple(irVal); + } + else + { // This is a run-time type check from for an existential type. + auto value = lowerLValueExpr(context, expr->value); + auto existentialInfo = value.getExtractedExistentialValInfo(); + auto irVal = builder->emitIsType( + existentialInfo->extractedVal, + existentialInfo->witnessTable, + type, + witness); + return LoweredValInfo::simple(irVal); + } } // For all other cases, we map to a simple type equality check in the IR. IRType* leftType = nullptr; @@ -5502,8 +5522,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> } auto rightType = lowerType(context, expr->typeExpr.type); IRInst* args[] = {leftType, rightType}; - auto irVal = - getBuilder()->emitIntrinsicInst(getBuilder()->getBoolType(), kIROp_TypeEquals, 2, args); + auto irVal = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_TypeEquals, 2, args); return LoweredValInfo::simple(irVal); } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index a019f97c4..a2fd944eb 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1692,6 +1692,8 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP Token whereToken; while (AdvanceIf(parser, "where", &whereToken)) { + bool optional = AdvanceIf(parser, "optional", &whereToken); + auto subType = parser->ParseTypeExp(); if (AdvanceIf(parser, TokenType::Colon)) { @@ -1702,6 +1704,12 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP parser->FillPosition(constraint); constraint->sub = subType; constraint->sup = parser->ParseTypeExp(); + if (optional) + { + addModifier( + constraint, + parser->astBuilder->create<OptionalConstraintModifier>()); + } AddMember(genericParent, constraint); if (!AdvanceIf(parser, TokenType::Comma)) break; @@ -1715,6 +1723,10 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP parser->FillPosition(constraint); constraint->sub = subType; constraint->sup = parser->ParseTypeExp(); + if (optional) + { + addModifier(constraint, parser->astBuilder->create<OptionalConstraintModifier>()); + } AddMember(genericParent, constraint); } else if (AdvanceIf(parser, TokenType::LParent)) diff --git a/tests/language-feature/generics/where-optional-1.slang b/tests/language-feature/generics/where-optional-1.slang new file mode 100644 index 000000000..da4bdaacb --- /dev/null +++ b/tests/language-feature/generics/where-optional-1.slang @@ -0,0 +1,11 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): +interface IThing +{ + void thing(); +} + +void f<T>(T t) where optional T: IThing +{ + // Unchecked optional constraint is an error. + t.thing(); // CHECK: error 30403 +} diff --git a/tests/language-feature/generics/where-optional-2.slang b/tests/language-feature/generics/where-optional-2.slang new file mode 100644 index 000000000..67679bd8d --- /dev/null +++ b/tests/language-feature/generics/where-optional-2.slang @@ -0,0 +1,63 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +interface IThing +{ + [mutating] + int thing(int index); +} + +struct MyThing: IThing +{ + int val; + + [mutating] + int thing(int index) + { + val++; + outputBuffer[index] = val; + return val; + } +} + +struct NotMyThing +{ + int val; +} + +void f<T>(inout T t, int index) where optional T: IThing +{ + if (T is IThing) + { + outputBuffer[index+1] = 2 * t.thing(index); + } + else + { + outputBuffer[index] = 0; + outputBuffer[index+1] = 0; + } +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + MyThing mt = MyThing(0); + NotMyThing nt = NotMyThing(1); + + // CHECK: 1 + // CHECK-NEXT: 2 + f<MyThing>(mt, 0); + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + f<NotMyThing>(nt, 2); + // CHECK: 2 + // CHECK-NEXT: 4 + f(mt, 4); + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + f(nt, 6); +} diff --git a/tests/language-feature/generics/where-optional-3.slang b/tests/language-feature/generics/where-optional-3.slang new file mode 100644 index 000000000..f8b3a4907 --- /dev/null +++ b/tests/language-feature/generics/where-optional-3.slang @@ -0,0 +1,95 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +interface IReleaseable +{ + [mutating] + void release(); +} + +struct Container<K, V> + where optional K : IReleaseable + where optional V : IReleaseable +{ + K k[2]; + V v[2]; + + [mutating] + void erase(int index) + { + if (K is IReleaseable) + k[index].release(); + if (V is IReleaseable) + v[index].release(); + } +} + +struct HeavyEntry: IReleaseable +{ + int index; + int value; + + [mutating] + void release() + { + outputBuffer[index] = value; + } +}; + +struct LightEntry +{ + int value; +}; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + { // Neither is IReleaseable + var c = Container<LightEntry, LightEntry>(); + c.k[0] = LightEntry(1); + c.k[1] = LightEntry(2); + c.v[0] = LightEntry(3); + c.v[1] = LightEntry(4); + c.erase(0); + c.erase(1); + } + { // K is IReleaseable + var c = Container<HeavyEntry, LightEntry>(); + c.k[0] = HeavyEntry(0,1); + c.k[1] = HeavyEntry(1,2); + c.v[0] = LightEntry(3); + c.v[1] = LightEntry(4); + // CHECK: 1 + c.erase(0); + // CHECK-NEXT: 2 + c.erase(1); + } + { // V is IReleaseable + var c = Container<LightEntry, HeavyEntry>(); + c.k[0] = LightEntry(1); + c.k[1] = LightEntry(2); + c.v[0] = HeavyEntry(2,3); + c.v[1] = HeavyEntry(3,4); + // CHECK-NEXT: 3 + c.erase(0); + // CHECK-NEXT: 4 + c.erase(1); + } + { // K and V are IReleaseable + var c = Container<HeavyEntry, HeavyEntry>(); + c.k[0] = HeavyEntry(4,5); + c.k[1] = HeavyEntry(6,7); + c.v[0] = HeavyEntry(5,6); + c.v[1] = HeavyEntry(7,8); + // CHECK-NEXT: 5 + // CHECK-NEXT: 6 + c.erase(0); + // CHECK-NEXT: 7 + // CHECK-NEXT: 8 + c.erase(1); + } +} diff --git a/tests/language-feature/generics/where-optional-4.slang b/tests/language-feature/generics/where-optional-4.slang new file mode 100644 index 000000000..6d72186d9 --- /dev/null +++ b/tests/language-feature/generics/where-optional-4.slang @@ -0,0 +1,15 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): +interface IThing +{ + void thing(); +} + +void g<T>(T t) where T: IThing +{ +} + +void f<T>(T t) where optional T: IThing +{ + // Error: cannot upgrade optional to non-optional witness in unchecked context. + g<T>(t); // CHECK: error 38029 +} diff --git a/tests/language-feature/generics/where-optional-5.slang b/tests/language-feature/generics/where-optional-5.slang new file mode 100644 index 000000000..3ce8041ad --- /dev/null +++ b/tests/language-feature/generics/where-optional-5.slang @@ -0,0 +1,17 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): +interface IThing +{ + void thing(); +} + +void f<T, U>(T t, U u) + where optional T: IThing + where optional U: IThing +{ + // Error: cannot upgrade optional to non-optional witness in unchecked context. + if (U is IThing) + { + // U being IThing doesn't justify using T as such! + t.thing(); // CHECK: error 30403 + } +} diff --git a/tests/language-feature/interface-as-rhs-error.slang b/tests/language-feature/interface-as-rhs-error.slang index 9ad71afde..7293b5134 100644 --- a/tests/language-feature/interface-as-rhs-error.slang +++ b/tests/language-feature/interface-as-rhs-error.slang @@ -21,13 +21,13 @@ struct AnotherType // These should produce errors - interface types as RHS bool testIsOperatorWithInterface<T>() { - //CHECK: ([[# @LINE+1]]): error 30301: 'is' and 'as' operators do not support interface types as the right-hand side + //CHECK: ([[# @LINE+1]]): error 30301: cannot use 'is' operator with an interface type as the right-hand side return (T is IMyInterface); } void testAsOperatorWithInterface<T>(T value) { - //CHECK: ([[# @LINE+1]]): error 30301: 'is' and 'as' operators do not support interface types as the right-hand side + //CHECK: ([[# @LINE+1]]): error 30302: cannot use 'as' operator with an interface type as the right-hand side let result = value as IMyInterface; } |
