summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/user-guide/06-interfaces-generics.md22
-rw-r--r--source/slang/slang-ast-modifier.h7
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-ast-val.cpp12
-rw-r--r--source/slang/slang-ast-val.h17
-rw-r--r--source/slang/slang-check-conformance.cpp20
-rw-r--r--source/slang/slang-check-constraint.cpp46
-rw-r--r--source/slang/slang-check-expr.cpp124
-rw-r--r--source/slang/slang-check-impl.h8
-rw-r--r--source/slang/slang-check-overload.cpp10
-rw-r--r--source/slang/slang-check-stmt.cpp5
-rw-r--r--source/slang/slang-diagnostic-defs.h21
-rw-r--r--source/slang/slang-ir-lower-generic-call.cpp10
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp20
-rw-r--r--source/slang/slang-ir-specialize-dispatch.cpp34
-rw-r--r--source/slang/slang-ir-witness-table-wrapper.cpp4
-rw-r--r--source/slang/slang-lower-to-ir.cpp43
-rw-r--r--source/slang/slang-parser.cpp12
-rw-r--r--tests/language-feature/generics/where-optional-1.slang11
-rw-r--r--tests/language-feature/generics/where-optional-2.slang63
-rw-r--r--tests/language-feature/generics/where-optional-3.slang95
-rw-r--r--tests/language-feature/generics/where-optional-4.slang15
-rw-r--r--tests/language-feature/generics/where-optional-5.slang17
-rw-r--r--tests/language-feature/interface-as-rhs-error.slang4
24 files changed, 556 insertions, 65 deletions
diff --git a/docs/user-guide/06-interfaces-generics.md b/docs/user-guide/06-interfaces-generics.md
index bb19fd776..7c957f685 100644
--- a/docs/user-guide/06-interfaces-generics.md
+++ b/docs/user-guide/06-interfaces-generics.md
@@ -144,6 +144,28 @@ struct MyType<T, U>
}
```
+Optional conformances can be expressed compactly using the `where optional` syntax:
+```csharp
+// Together, these two overloads...
+int myGenericMethod<T>(T arg)
+{
+}
+
+int myGenericMethod<T>(T arg) where T: IFoo
+{
+ arg.myMethod(1.0);
+}
+
+// ... are equivalent to:
+int myGenericMethod<T>(T arg) where optional T: IFoo
+{
+ if (T is IFoo)
+ {
+ arg.myMethod(1.0); // OK in a block that checks for T: IFoo conformance.
+ }
+}
+```
+
Supported Constructs in Interface Definitions
-----------------------------------------------------
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index ca023a19b..88dea0b7e 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -219,6 +219,13 @@ class IgnoreForLookupModifier : public Modifier
FIDDLE(...)
};
+/// A modifier that indicates an `TypeConstraintDecl` is optional.
+FIDDLE()
+class OptionalConstraintModifier : public Modifier
+{
+ FIDDLE(...)
+};
+
// A modifier that marks something as an operation that
// has a one-to-one translation to the IR, and thus
// has no direct definition in the high-level language.
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index e4002c237..572d05d9f 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -123,6 +123,7 @@ FIDDLE() namespace Slang
kConversionCost_ValToOptional = 150,
kConversionCost_NullPtrToPtr = 150,
kConversionCost_PtrToVoidPtr = 150,
+ kConversionCost_FailedOptionalConstraint = 150,
// Conversions that are lossless, but change "kind"
kConversionCost_UnsignedToSignedPromotion = 200,
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 40daae3a6..0e942fe39 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -896,6 +896,18 @@ Val* TypeCoercionWitness::_resolveImplOverride()
return this;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NoneWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+void NoneWitness::_toTextOverride(StringBuilder& out)
+{
+ out.append("none");
+}
+
+Val* NoneWitness::_resolveImplOverride()
+{
+ return this;
+}
+
// UNormModifierVal
void UNormModifierVal::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index a8b969a94..f371d76fe 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -697,6 +697,13 @@ class DeclaredSubtypeWitness : public SubtypeWitness
return false;
}
+ bool isOptional()
+ {
+ if (auto declRef = getDeclRef().as<GenericTypeConstraintDecl>())
+ return declRef.getDecl()->hasModifier<OptionalConstraintModifier>();
+ return false;
+ }
+
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
@@ -832,6 +839,16 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness
ConversionCost _getOverloadResolutionCostOverride();
};
+/// A witness for the "none" value of optional constraints.
+FIDDLE()
+class NoneWitness : public Witness
+{
+ FIDDLE(...)
+
+ void _toTextOverride(StringBuilder& out);
+ Val* _resolveImplOverride();
+};
+
/// A value that represents a modifier attached to some other value
FIDDLE()
class ModifierVal : public Val
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index 71b87f447..ba1b8ea55 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -122,6 +122,8 @@ SubtypeWitness* SemanticsVisitor::checkAndConstructSubtypeWitness(
ensureDecl(superDeclRefType->getDeclRef().getDecl(), DeclCheckState::ReadyForLookup);
}
+ SubtypeWitness* failureWitness = nullptr;
+
// In the common case, we can use the pre-computed inheritance information for `subType`
// to enumerate all the types it transitively inherits from.
//
@@ -148,9 +150,19 @@ SubtypeWitness* SemanticsVisitor::checkAndConstructSubtypeWitness(
// If the `superType` appears in the flattened inheritance list
// for the `subType`, then we know that the subtype relationship
- // holds. Conveniently, the `facet` stores a pre-computed witness
- // for the subtype relationship, which we can return here.
- //
+ // holds.
+
+ // If the witness is optional, we should only return it if no certain
+ // witness was found.
+ auto declWitness = as<DeclaredSubtypeWitness>(facet->subtypeWitness);
+ if (declWitness && declWitness->isOptional())
+ {
+ failureWitness = facet->subtypeWitness;
+ continue;
+ }
+
+ // Conveniently, the `facet` stores a pre-computed witness for the
+ // subtype relationship, which we can use here.
return facet->subtypeWitness;
}
//
@@ -271,7 +283,7 @@ SubtypeWitness* SemanticsVisitor::checkAndConstructSubtypeWitness(
return m_astBuilder->getEachSubtypeWitness(subType, superType, elementWitness);
}
// default is failure
- return nullptr;
+ return failureWitness;
}
bool SemanticsVisitor::isValidGenericConstraintType(Type* type)
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 3020554c8..7c55c440c 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -360,9 +360,12 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
for (auto constraintDeclRef :
getMembersOfType<GenericTypeConstraintDecl>(m_astBuilder, genericDeclRef))
{
+ ValUnificationContext unificationContext;
+ unificationContext.optionalConstraint =
+ constraintDeclRef.getDecl()->hasModifier<OptionalConstraintModifier>();
if (!TryUnifyTypes(
*system,
- ValUnificationContext(),
+ unificationContext,
getSub(m_astBuilder, constraintDeclRef),
getSup(m_astBuilder, constraintDeclRef)))
return DeclRef<Decl>();
@@ -487,8 +490,11 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
auto joinType = TryJoinTypes(system, type, cType);
if (!joinType)
{
- // failure!
- return DeclRef<Decl>();
+ if (c.isOptional)
+ joinType = type;
+ else
+ // failure!
+ return DeclRef<Decl>();
}
type = QualType(joinType, type.isLeftValue || cType.isLeftValue);
}
@@ -696,12 +702,22 @@ DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
subTypeWitness = nullptr;
}
- if (subTypeWitness)
+ bool witnessIsOptional = isWitnessUncheckedOptional(subTypeWitness);
+ bool constraintIsOptional = constraintDecl->hasModifier<OptionalConstraintModifier>();
+
+ if (subTypeWitness && (!witnessIsOptional || constraintIsOptional))
{
// We found a witness, so it will become an (implicit) argument.
args.add(subTypeWitness);
outBaseCost += subTypeWitness->getOverloadResolutionCost();
}
+ else if (!subTypeWitness && constraintIsOptional)
+ {
+ // Optional witness failed to resolve; not an error.
+ auto noneWitness = m_astBuilder->getOrCreate<NoneWitness>();
+ args.add(noneWitness);
+ outBaseCost += kConversionCost_FailedOptionalConstraint;
+ }
else
{
// No witness was found, so the inference will now fail.
@@ -851,13 +867,20 @@ bool SemanticsVisitor::TryUnifyVals(
// Two subtype witnesses can be unified if they exist (non-null) and
// prove that some pair of types are subtypes of types that can be unified.
//
- if (auto fstWit = as<SubtypeWitness>(fst))
- {
- if (auto sndWit = as<SubtypeWitness>(snd))
- {
- return TryUnifyTypes(constraints, unifyCtx, fstWit->getSup(), sndWit->getSup());
- }
- }
+ const auto fstSubtypeWitness = as<SubtypeWitness>(fst);
+ const auto sndSubtypeWitness = as<SubtypeWitness>(snd);
+ const auto fstNoneWitness = as<NoneWitness>(fst);
+ const auto sndNoneWitness = as<NoneWitness>(snd);
+ if (fstSubtypeWitness && sndSubtypeWitness)
+ return TryUnifyTypes(
+ constraints,
+ unifyCtx,
+ fstSubtypeWitness->getSup(),
+ sndSubtypeWitness->getSup());
+ else if (fstNoneWitness && sndNoneWitness)
+ return true;
+ else if ((fstNoneWitness && sndSubtypeWitness) || (fstSubtypeWitness && sndNoneWitness))
+ return false;
SLANG_UNIMPLEMENTED_X("value unification case");
@@ -946,6 +969,7 @@ bool SemanticsVisitor::TryUnifyTypeParam(
constraint.indexInPack = unificationContext.indexInTypePack;
constraint.val = type;
constraint.isUsedAsLValue = type.isLeftValue;
+ constraint.isOptional = unificationContext.optionalConstraint;
constraints.constraints.add(constraint);
return true;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 4c6bf98d2..306687bd8 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1023,6 +1023,100 @@ LookupResult SemanticsVisitor::filterLookupResultByVisibilityAndDiagnose(
return result;
}
+bool SemanticsVisitor::isWitnessUncheckedOptional(SubtypeWitness* witness)
+{
+ auto declaredWitness = as<DeclaredSubtypeWitness>(witness);
+ if (!declaredWitness)
+ return false;
+
+ auto decl = declaredWitness->getDeclRef().getDecl();
+ if (!decl || !decl->hasModifier<OptionalConstraintModifier>())
+ return false;
+
+ // Okay, we've found an optional subtype witness. This result needs
+ // to be removed if we're not inside a block that directly checks
+ // if (sub is sup)
+ auto sub = witness->getSub();
+ auto sup = witness->getSup();
+
+ for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
+ {
+ auto outerStmt = outerStmtInfo->stmt;
+ auto ifStmt = as<IfStmt>(outerStmt);
+
+ if (!ifStmt)
+ continue;
+
+ IsTypeExpr* isType = as<IsTypeExpr>(ifStmt->predicate);
+ if (!isType)
+ continue;
+ VarExpr* var = as<VarExpr>(isType->value);
+ if (!var)
+ continue;
+ TypeType* typeType = as<TypeType>(var->type);
+
+ // var->type works for `variable is Interface`, while
+ // typeType->getType() is for `T is Interface`.
+ auto type = typeType ? typeType->getType() : var->type.type;
+ if (type == sub && isType->typeExpr.type == sup)
+ {
+ return false;
+ }
+ }
+
+ // If we got this far, it's both an optional witness and there's no
+ // statement checking its validity.
+ return true;
+}
+
+LookupResult SemanticsVisitor::filterLookupResultByCheckedOptional(const LookupResult& lookupResult)
+{
+ LookupResult filteredResult;
+ for (auto item : lookupResult)
+ {
+ bool optionalConstraintsChecked = true;
+
+ for (auto bb = item.breadcrumbs; bb; bb = bb->next)
+ {
+ auto witness = as<SubtypeWitness>(bb->val);
+ if (!witness)
+ continue;
+
+ if (isWitnessUncheckedOptional(witness))
+ {
+ optionalConstraintsChecked = false;
+ break;
+ }
+ }
+
+ if (optionalConstraintsChecked)
+ AddToLookupResult(filteredResult, item);
+ }
+ return filteredResult;
+}
+
+LookupResult SemanticsVisitor::filterLookupResultByCheckedOptionalAndDiagnose(
+ const LookupResult& lookupResult,
+ SourceLoc loc,
+ bool& outDiagnosed)
+{
+ auto result = filterLookupResultByCheckedOptional(lookupResult);
+ if (lookupResult.isValid() && !result.isValid())
+ {
+ getSink()->diagnose(
+ loc,
+ Diagnostics::requiredConstraintIsNotChecked,
+ lookupResult.item.declRef);
+ outDiagnosed = true;
+
+ if (getShared()->isInLanguageServer())
+ {
+ return lookupResult;
+ }
+ }
+ return result;
+}
+
LookupResult SemanticsVisitor::resolveOverloadedLookup(LookupResult const& inResult)
{
// If the result isn't actually overloaded, it is fine as-is
@@ -4068,19 +4162,16 @@ Expr* SemanticsExprVisitor::visitIsTypeExpr(IsTypeExpr* expr)
expr->type = m_astBuilder->getBoolType();
expr->value = originalVal;
- // Check if the right-hand side type is an interface type
- if (isInterfaceType(expr->typeExpr.type))
- {
- getSink()->diagnose(expr, Diagnostics::isAsOperatorCannotUseInterfaceAsRHS);
- return expr;
- }
-
auto valueType = expr->value->type.type;
if (auto typeType = as<TypeType>(valueType))
valueType = typeType->getType();
// If value is a subtype of `type`, then this expr is always true.
- if (isSubtype(valueType, expr->typeExpr.type, IsSubTypeOptions::None))
+ auto witness = isSubtype(valueType, expr->typeExpr.type, IsSubTypeOptions::None);
+ auto declWitness = as<DeclaredSubtypeWitness>(witness);
+ bool optionalWitness = declWitness && declWitness->isOptional();
+
+ if (witness && !optionalWitness)
{
// Instead of returning a BoolLiteralExpr, we use a field to indicate this scenario,
// so that the language server can still see the original syntax tree.
@@ -4091,15 +4182,24 @@ Expr* SemanticsExprVisitor::visitIsTypeExpr(IsTypeExpr* expr)
return expr;
}
+ // Check if the right-hand side type is an interface type. For 'is'
+ // statements, that's only allowed if it's related to an optional
+ // constraint.
+ if (isInterfaceType(expr->typeExpr.type) && !optionalWitness)
+ {
+ getSink()->diagnose(expr, Diagnostics::isOperatorCannotUseInterfaceAsRHS);
+ return expr;
+ }
+
// Otherwise, if the target type is a subtype of value->type, we need to grab the
// subtype witness for runtime checks.
expr->value = maybeOpenExistential(originalVal);
- expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, valueType);
+ expr->witnessArg = witness ? witness : tryGetSubtypeWitness(expr->typeExpr.type, valueType);
if (expr->witnessArg)
{
// For now we can only support the scenario where `expr->value` is an interface type.
- if (!isInterfaceType(originalVal->type))
+ if (!optionalWitness && !isInterfaceType(originalVal->type))
{
getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType);
}
@@ -4117,7 +4217,7 @@ Expr* SemanticsExprVisitor::visitAsTypeExpr(AsTypeExpr* expr)
// Check if the right-hand side type is an interface type
if (isInterfaceType(typeExpr.type))
{
- getSink()->diagnose(expr, Diagnostics::isAsOperatorCannotUseInterfaceAsRHS);
+ getSink()->diagnose(expr, Diagnostics::asOperatorCannotUseInterfaceAsRHS);
expr->type = m_astBuilder->getErrorType();
return expr;
}
@@ -5162,6 +5262,8 @@ Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* bas
lookUpMember(m_astBuilder, this, expr->name, baseType, m_outerScope);
bool diagnosed = false;
lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed);
+ lookupResult =
+ filterLookupResultByCheckedOptionalAndDiagnose(lookupResult, expr->loc, diagnosed);
if (!lookupResult.isValid())
{
return lookupMemberResultFailure(expr, baseType, diagnosed);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index d4d431914..e73f65d75 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1372,6 +1372,13 @@ public:
SourceLoc loc,
bool& outDiagnosed);
+ bool isWitnessUncheckedOptional(SubtypeWitness* witness);
+ LookupResult filterLookupResultByCheckedOptional(const LookupResult& lookupResult);
+ LookupResult filterLookupResultByCheckedOptionalAndDiagnose(
+ const LookupResult& lookupResult,
+ SourceLoc loc,
+ bool& outDiagnosed);
+
Val* resolveVal(Val* val)
{
if (!val)
@@ -2650,6 +2657,7 @@ public:
struct ValUnificationContext
{
Index indexInTypePack = 0;
+ bool optionalConstraint = false;
};
// Try to find a unification for two values
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 7a89f2e23..6c0a7f184 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -994,10 +994,18 @@ bool SemanticsVisitor::TryCheckOverloadCandidateConstraints(
auto sup = getSup(m_astBuilder, constraintDeclRef);
auto subTypeWitness = tryGetSubtypeWitness(sub, sup);
- if (subTypeWitness)
+
+ bool witnessIsOptional = isWitnessUncheckedOptional(subTypeWitness);
+ bool constraintIsOptional = constraintDecl->hasModifier<OptionalConstraintModifier>();
+
+ if (subTypeWitness && (!witnessIsOptional || constraintIsOptional))
{
newArgs.add(subTypeWitness);
}
+ else if (!subTypeWitness && constraintIsOptional)
+ {
+ newArgs.add(m_astBuilder->getOrCreate<NoneWitness>());
+ }
else
{
if (context.mode != OverloadResolveContext::Mode::JustTrying)
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index faf25b716..bf8ddc94e 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -522,9 +522,10 @@ void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt)
void SemanticsStmtVisitor::visitIfStmt(IfStmt* stmt)
{
+ WithOuterStmt subContext(this, stmt);
stmt->predicate = checkPredicateExpr(stmt->predicate);
- checkStmt(stmt->positiveStatement);
- checkStmt(stmt->negativeStatement);
+ subContext.checkStmt(stmt->positiveStatement);
+ subContext.checkStmt(stmt->negativeStatement);
}
void SemanticsStmtVisitor::visitUnparsedStmt(UnparsedStmt*)
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d54e1a3e0..384a81f9b 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -845,9 +845,18 @@ DIAGNOSTIC(
DIAGNOSTIC(
30301,
Error,
- isAsOperatorCannotUseInterfaceAsRHS,
- "'is' and 'as' operators do not support interface types as the right-hand side. Use a concrete "
- "type instead.")
+ isOperatorCannotUseInterfaceAsRHS,
+ "cannot use 'is' operator with an interface type as the right-hand "
+ "side without a corresponding optional constraint. Use a concrete type "
+ "instead, or add an optional constraint for the interface type.")
+
+DIAGNOSTIC(
+ 30302,
+ Error,
+ asOperatorCannotUseInterfaceAsRHS,
+ "cannot use 'as' operator with an interface type as the right-hand "
+ "side. Use a concrete type instead. If you want to use an optional "
+ "constraint, use an 'if (T is IInterface)' block instead.")
DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'")
@@ -1610,6 +1619,12 @@ DIAGNOSTIC(
Error,
invalidConstraintSubType,
"type '$0' is not a valid left hand side of a type constraint.")
+DIAGNOSTIC(
+ 30403,
+ Error,
+ requiredConstraintIsNotChecked,
+ "the constraint providing '$0' is optional and must be checked with an 'is' statement before "
+ "usage.")
// 305xx: initializer lists
DIAGNOSTIC(30500, Error, tooManyInitializers, "too many initializers (expected $0, got $1)")
diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp
index 80b258407..41a78d00c 100644
--- a/source/slang/slang-ir-lower-generic-call.cpp
+++ b/source/slang/slang-ir-lower-generic-call.cpp
@@ -295,9 +295,17 @@ struct GenericCallLoweringContext
return;
}
- auto interfaceType = cast<IRInterfaceType>(
+ auto interfaceType = as<IRInterfaceType>(
cast<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())
->getConformanceType());
+
+ if (!interfaceType)
+ {
+ // NoneWitness -> remove call.
+ callInst->removeAndDeallocate();
+ return;
+ }
+
if (isBuiltin(interfaceType))
return;
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index f2b1d1a6a..c003d6125 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -296,8 +296,11 @@ struct GenericFunctionLoweringContext
// and emission of wrapper functions.
void lowerWitnessTable(IRWitnessTable* witnessTable)
{
- auto interfaceType =
- maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTable->getConformanceType()));
+ IRInterfaceType* conformanceType = as<IRInterfaceType>(witnessTable->getConformanceType());
+ if (!conformanceType)
+ return;
+
+ auto interfaceType = maybeLowerInterfaceType(conformanceType);
IRBuilder builderStorage(sharedContext->module);
auto builder = &builderStorage;
builder->setInsertBefore(witnessTable);
@@ -353,8 +356,17 @@ struct GenericFunctionLoweringContext
return;
if (witnessTableType->getConformanceType()->findDecoration<IRComInterfaceDecoration>())
return;
- auto interfaceType =
- maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType()));
+
+ IRInterfaceType* conformanceType =
+ as<IRInterfaceType>(witnessTableType->getConformanceType());
+
+ // NoneWitness generates conformance types which aren't interfaces. In
+ // that case, the method can just be skipped entirely, since there's no
+ // real witness for it and it should be in unreachable code at this
+ // point.
+ if (!conformanceType)
+ return;
+ auto interfaceType = maybeLowerInterfaceType(conformanceType);
interfaceRequirementVal = sharedContext->findInterfaceRequirementVal(
interfaceType,
lookupInst->getRequirementKey());
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
index e4d1af93a..0c8b248b0 100644
--- a/source/slang/slang-ir-specialize-dispatch.cpp
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -232,23 +232,31 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex
{
auto interfaceType =
cast<IRWitnessTableType>(inst->getDataType())->getConformanceType();
- auto interfaceLinkage = interfaceType->findDecoration<IRLinkageDecoration>();
- SLANG_ASSERT(
- interfaceLinkage && "An interface type does not have a linkage,"
- "but a witness table associated with it has one.");
- auto interfaceName = interfaceLinkage->getMangledName();
- auto idAllocator =
- linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(
- interfaceName);
- if (!idAllocator)
+ if (as<IRInterfaceType>(interfaceType))
{
- linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0;
- idAllocator =
+ auto interfaceLinkage = interfaceType->findDecoration<IRLinkageDecoration>();
+ SLANG_ASSERT(
+ interfaceLinkage && "An interface type does not have a linkage,"
+ "but a witness table associated with it has one.");
+ auto interfaceName = interfaceLinkage->getMangledName();
+ auto idAllocator =
linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(
interfaceName);
+ if (!idAllocator)
+ {
+ linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0;
+ idAllocator =
+ linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue(
+ interfaceName);
+ }
+ seqID = *idAllocator;
+ ++(*idAllocator);
+ }
+ else
+ {
+ // NoneWitness, has special ID of -1.
+ seqID = uint32_t(-1);
}
- seqID = *idAllocator;
- ++(*idAllocator);
linkage->mapMangledNameToRTTIObjectIndex[witnessTableMangledName] = seqID;
}
diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp
index fabfd1611..cfda6225e 100644
--- a/source/slang/slang-ir-witness-table-wrapper.cpp
+++ b/source/slang/slang-ir-witness-table-wrapper.cpp
@@ -171,7 +171,9 @@ struct GenerateWitnessTableWrapperContext
void lowerWitnessTable(IRWitnessTable* witnessTable)
{
- auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
+ auto interfaceType = as<IRInterfaceType>(witnessTable->getConformanceType());
+ if (!interfaceType)
+ return;
if (isBuiltin(interfaceType))
return;
if (isComInterfaceType(interfaceType))
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index df3507670..8e13a3dc5 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1990,6 +1990,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
builder->emitGetTupleElement(elementType, conjunctionWitness, indexInConjunction));
}
+ LoweredValInfo visitNoneWitness(NoneWitness*)
+ {
+ auto builder = getBuilder();
+ auto voidType = builder->getVoidType();
+ return LoweredValInfo::simple(builder->createWitnessTable(voidType, voidType));
+ }
LoweredValInfo visitConstantIntVal(ConstantIntVal* val)
{
@@ -5472,23 +5478,37 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo visitIsTypeExpr(IsTypeExpr* expr)
{
+ auto builder = getBuilder();
if (expr->constantVal)
{
- return LoweredValInfo::simple(getBuilder()->getBoolValue(expr->constantVal->value));
+ return LoweredValInfo::simple(builder->getBoolValue(expr->constantVal->value));
}
- // If expr is a witness, then this is a run-time type check from for an existential type.
if (expr->witnessArg)
{
- auto value = lowerLValueExpr(context, expr->value);
auto type = lowerType(context, expr->typeExpr.type);
auto witness = lowerSimpleVal(context, expr->witnessArg);
- auto existentialInfo = value.getExtractedExistentialValInfo();
- auto irVal = getBuilder()->emitIsType(
- existentialInfo->extractedVal,
- existentialInfo->witnessTable,
- type,
- witness);
- return LoweredValInfo::simple(irVal);
+ auto declWitness = as<DeclaredSubtypeWitness>(expr->witnessArg);
+
+ if (declWitness && declWitness->isOptional())
+ {
+ // Optional constraint check. NoneWitness lowers to a specific
+ // ID, so that we can check for that here.
+ auto witnessID = builder->emitGetSequentialIDInst(witness);
+ auto noneWitnessID = builder->getIntValue(builder->getUIntType(), -1);
+ auto irVal = builder->emitNeq(witnessID, noneWitnessID);
+ return LoweredValInfo::simple(irVal);
+ }
+ else
+ { // This is a run-time type check from for an existential type.
+ auto value = lowerLValueExpr(context, expr->value);
+ auto existentialInfo = value.getExtractedExistentialValInfo();
+ auto irVal = builder->emitIsType(
+ existentialInfo->extractedVal,
+ existentialInfo->witnessTable,
+ type,
+ witness);
+ return LoweredValInfo::simple(irVal);
+ }
}
// For all other cases, we map to a simple type equality check in the IR.
IRType* leftType = nullptr;
@@ -5502,8 +5522,7 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
}
auto rightType = lowerType(context, expr->typeExpr.type);
IRInst* args[] = {leftType, rightType};
- auto irVal =
- getBuilder()->emitIntrinsicInst(getBuilder()->getBoolType(), kIROp_TypeEquals, 2, args);
+ auto irVal = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_TypeEquals, 2, args);
return LoweredValInfo::simple(irVal);
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index a019f97c4..a2fd944eb 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1692,6 +1692,8 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP
Token whereToken;
while (AdvanceIf(parser, "where", &whereToken))
{
+ bool optional = AdvanceIf(parser, "optional", &whereToken);
+
auto subType = parser->ParseTypeExp();
if (AdvanceIf(parser, TokenType::Colon))
{
@@ -1702,6 +1704,12 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP
parser->FillPosition(constraint);
constraint->sub = subType;
constraint->sup = parser->ParseTypeExp();
+ if (optional)
+ {
+ addModifier(
+ constraint,
+ parser->astBuilder->create<OptionalConstraintModifier>());
+ }
AddMember(genericParent, constraint);
if (!AdvanceIf(parser, TokenType::Comma))
break;
@@ -1715,6 +1723,10 @@ static void maybeParseGenericConstraints(Parser* parser, ContainerDecl* genericP
parser->FillPosition(constraint);
constraint->sub = subType;
constraint->sup = parser->ParseTypeExp();
+ if (optional)
+ {
+ addModifier(constraint, parser->astBuilder->create<OptionalConstraintModifier>());
+ }
AddMember(genericParent, constraint);
}
else if (AdvanceIf(parser, TokenType::LParent))
diff --git a/tests/language-feature/generics/where-optional-1.slang b/tests/language-feature/generics/where-optional-1.slang
new file mode 100644
index 000000000..da4bdaacb
--- /dev/null
+++ b/tests/language-feature/generics/where-optional-1.slang
@@ -0,0 +1,11 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+interface IThing
+{
+ void thing();
+}
+
+void f<T>(T t) where optional T: IThing
+{
+ // Unchecked optional constraint is an error.
+ t.thing(); // CHECK: error 30403
+}
diff --git a/tests/language-feature/generics/where-optional-2.slang b/tests/language-feature/generics/where-optional-2.slang
new file mode 100644
index 000000000..67679bd8d
--- /dev/null
+++ b/tests/language-feature/generics/where-optional-2.slang
@@ -0,0 +1,63 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+interface IThing
+{
+ [mutating]
+ int thing(int index);
+}
+
+struct MyThing: IThing
+{
+ int val;
+
+ [mutating]
+ int thing(int index)
+ {
+ val++;
+ outputBuffer[index] = val;
+ return val;
+ }
+}
+
+struct NotMyThing
+{
+ int val;
+}
+
+void f<T>(inout T t, int index) where optional T: IThing
+{
+ if (T is IThing)
+ {
+ outputBuffer[index+1] = 2 * t.thing(index);
+ }
+ else
+ {
+ outputBuffer[index] = 0;
+ outputBuffer[index+1] = 0;
+ }
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ MyThing mt = MyThing(0);
+ NotMyThing nt = NotMyThing(1);
+
+ // CHECK: 1
+ // CHECK-NEXT: 2
+ f<MyThing>(mt, 0);
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ f<NotMyThing>(nt, 2);
+ // CHECK: 2
+ // CHECK-NEXT: 4
+ f(mt, 4);
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ f(nt, 6);
+}
diff --git a/tests/language-feature/generics/where-optional-3.slang b/tests/language-feature/generics/where-optional-3.slang
new file mode 100644
index 000000000..f8b3a4907
--- /dev/null
+++ b/tests/language-feature/generics/where-optional-3.slang
@@ -0,0 +1,95 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+interface IReleaseable
+{
+ [mutating]
+ void release();
+}
+
+struct Container<K, V>
+ where optional K : IReleaseable
+ where optional V : IReleaseable
+{
+ K k[2];
+ V v[2];
+
+ [mutating]
+ void erase(int index)
+ {
+ if (K is IReleaseable)
+ k[index].release();
+ if (V is IReleaseable)
+ v[index].release();
+ }
+}
+
+struct HeavyEntry: IReleaseable
+{
+ int index;
+ int value;
+
+ [mutating]
+ void release()
+ {
+ outputBuffer[index] = value;
+ }
+};
+
+struct LightEntry
+{
+ int value;
+};
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ { // Neither is IReleaseable
+ var c = Container<LightEntry, LightEntry>();
+ c.k[0] = LightEntry(1);
+ c.k[1] = LightEntry(2);
+ c.v[0] = LightEntry(3);
+ c.v[1] = LightEntry(4);
+ c.erase(0);
+ c.erase(1);
+ }
+ { // K is IReleaseable
+ var c = Container<HeavyEntry, LightEntry>();
+ c.k[0] = HeavyEntry(0,1);
+ c.k[1] = HeavyEntry(1,2);
+ c.v[0] = LightEntry(3);
+ c.v[1] = LightEntry(4);
+ // CHECK: 1
+ c.erase(0);
+ // CHECK-NEXT: 2
+ c.erase(1);
+ }
+ { // V is IReleaseable
+ var c = Container<LightEntry, HeavyEntry>();
+ c.k[0] = LightEntry(1);
+ c.k[1] = LightEntry(2);
+ c.v[0] = HeavyEntry(2,3);
+ c.v[1] = HeavyEntry(3,4);
+ // CHECK-NEXT: 3
+ c.erase(0);
+ // CHECK-NEXT: 4
+ c.erase(1);
+ }
+ { // K and V are IReleaseable
+ var c = Container<HeavyEntry, HeavyEntry>();
+ c.k[0] = HeavyEntry(4,5);
+ c.k[1] = HeavyEntry(6,7);
+ c.v[0] = HeavyEntry(5,6);
+ c.v[1] = HeavyEntry(7,8);
+ // CHECK-NEXT: 5
+ // CHECK-NEXT: 6
+ c.erase(0);
+ // CHECK-NEXT: 7
+ // CHECK-NEXT: 8
+ c.erase(1);
+ }
+}
diff --git a/tests/language-feature/generics/where-optional-4.slang b/tests/language-feature/generics/where-optional-4.slang
new file mode 100644
index 000000000..6d72186d9
--- /dev/null
+++ b/tests/language-feature/generics/where-optional-4.slang
@@ -0,0 +1,15 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+interface IThing
+{
+ void thing();
+}
+
+void g<T>(T t) where T: IThing
+{
+}
+
+void f<T>(T t) where optional T: IThing
+{
+ // Error: cannot upgrade optional to non-optional witness in unchecked context.
+ g<T>(t); // CHECK: error 38029
+}
diff --git a/tests/language-feature/generics/where-optional-5.slang b/tests/language-feature/generics/where-optional-5.slang
new file mode 100644
index 000000000..3ce8041ad
--- /dev/null
+++ b/tests/language-feature/generics/where-optional-5.slang
@@ -0,0 +1,17 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+interface IThing
+{
+ void thing();
+}
+
+void f<T, U>(T t, U u)
+ where optional T: IThing
+ where optional U: IThing
+{
+ // Error: cannot upgrade optional to non-optional witness in unchecked context.
+ if (U is IThing)
+ {
+ // U being IThing doesn't justify using T as such!
+ t.thing(); // CHECK: error 30403
+ }
+}
diff --git a/tests/language-feature/interface-as-rhs-error.slang b/tests/language-feature/interface-as-rhs-error.slang
index 9ad71afde..7293b5134 100644
--- a/tests/language-feature/interface-as-rhs-error.slang
+++ b/tests/language-feature/interface-as-rhs-error.slang
@@ -21,13 +21,13 @@ struct AnotherType
// These should produce errors - interface types as RHS
bool testIsOperatorWithInterface<T>()
{
- //CHECK: ([[# @LINE+1]]): error 30301: 'is' and 'as' operators do not support interface types as the right-hand side
+ //CHECK: ([[# @LINE+1]]): error 30301: cannot use 'is' operator with an interface type as the right-hand side
return (T is IMyInterface);
}
void testAsOperatorWithInterface<T>(T value)
{
- //CHECK: ([[# @LINE+1]]): error 30301: 'is' and 'as' operators do not support interface types as the right-hand side
+ //CHECK: ([[# @LINE+1]]): error 30302: cannot use 'as' operator with an interface type as the right-hand side
let result = value as IMyInterface;
}