diff options
| author | Yong He <yonghe@outlook.com> | 2022-08-10 14:11:27 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-08-10 14:11:27 -0700 |
| commit | 88f04c29244af23c1cdd472d8d1ae3e5a650494e (patch) | |
| tree | 398e55440e8f7ad157d15b2b75d9887236eaa126 /source | |
| parent | fcdb4629c4c3dd2931eaa88b96b668d914c4519c (diff) | |
`is` and `as` operator and `Optional<T>`. (#2355)
* `is` and `as` operator and `Optional<T>`.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
27 files changed, 846 insertions, 53 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 41f066486..616511b01 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -456,6 +456,25 @@ __intrinsic_type($(kIROp_RefType)) struct Ref {}; +__generic<T> +__magic_type(OptionalType) +__intrinsic_type($(kIROp_OptionalType)) +struct Optional +{ + property bool hasValue + { + __intrinsic_op($(kIROp_OptionalHasValue)) + get; + } + + property T value + { + __intrinsic_op($(kIROp_GetOptionalValue)) + get; + } +}; + + __magic_type(StringType) __intrinsic_type($(kIROp_StringType)) struct String diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 868763f76..2acdc26d3 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -212,6 +212,13 @@ NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) return (NodeBase*)createFunc(this); } +Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName) +{ + auto declRef = getBuiltinDeclRef(magicTypeName, makeConstArrayViewSingle<Val*>(typeParam)); + auto rsType = DeclRefType::create(this, declRef); + return rsType; +} + PtrType* ASTBuilder::getPtrType(Type* valueType) { return dynamicCast<PtrType>(getPtrType(valueType, "PtrType")); @@ -233,23 +240,15 @@ RefType* ASTBuilder::getRefType(Type* valueType) return dynamicCast<RefType>(getPtrType(valueType, "RefType")); } -PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) +OptionalType* ASTBuilder::getOptionalType(Type* valueType) { - auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl(ptrTypeName)); - return getPtrType(valueType, genericDecl); + auto rsType = getSpecializedBuiltinType(valueType, "OptionalType"); + return as<OptionalType>(rsType); } -PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, GenericDecl* genericDecl) +PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) { - auto typeDecl = genericDecl->inner; - - auto substitutions = create<GenericSubstitution>(); - substitutions->genericDecl = genericDecl; - substitutions->args.add(valueType); - - auto declRef = DeclRef<Decl>(typeDecl, substitutions); - auto rsType = DeclRefType::create(this, declRef); - return as<PtrTypeBase>(rsType); + return as<PtrTypeBase>(getSpecializedBuiltinType(valueType, ptrTypeName)); } ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount) diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 3c0303e70..e62e92a7b 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -131,6 +131,8 @@ public: /// Get a builtin type by the BaseType SLANG_FORCE_INLINE Type* getBuiltinType(BaseType flavor) { return m_sharedASTBuilder->m_builtinTypes[Index(flavor)]; } + Type* getSpecializedBuiltinType(Type* typeParam, const char* magicTypeName); + Type* getInitializerListType() { return m_sharedASTBuilder->m_initializerListType; } Type* getOverloadedType() { return m_sharedASTBuilder->m_overloadedType; } Type* getErrorType() { return m_sharedASTBuilder->m_errorType; } @@ -152,14 +154,13 @@ public: // Construct the type `Ref<valueType>` RefType* getRefType(Type* valueType); + // Construct the type `Optional<valueType>` + OptionalType* getOptionalType(Type* valueType); + // Construct a pointer type like `Ptr<valueType>`, but where // the actual type name for the pointer type is given by `ptrTypeName` PtrTypeBase* getPtrType(Type* valueType, char const* ptrTypeName); - // Construct a pointer type like `Ptr<valueType>`, but where - // the generic declaration for the pointer type is `genericDecl` - PtrTypeBase* getPtrType(Type* valueType, GenericDecl* genericDecl); - ArrayExpressionType* getArrayType(Type* elementType, IntVal* elementCount); VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount); diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 4a9cc475d..2dfd937a4 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -299,6 +299,44 @@ class CastToSuperTypeExpr: public Expr Val* witnessArg = nullptr; }; + /// A `value is Type` expression that evaluates to `true` if type of `value` is a sub-type of + /// `Type`. +class IsTypeExpr : public Expr +{ + SLANG_AST_CLASS(IsTypeExpr) + + Expr* value = nullptr; + TypeExp typeExpr; + + // A witness showing that `typeExpr.type` is a subtype of `typeof(value)`. + Val* witnessArg = nullptr; + + bool isAlwaysTrue = false; +}; + + /// A `value as Type` expression that casts `value` to `Type` within type hierarchy. + /// The result is undefined if `value` is not `Type`. +class AsTypeExpr : public Expr +{ + SLANG_AST_CLASS(AsTypeExpr) + + Expr* value = nullptr; + Expr* typeExpr = nullptr; + + // A witness showing that `typeExpr` is a subtype of `typeof(value)`. + Val* witnessArg = nullptr; + +}; + +class MakeOptionalExpr : public Expr +{ + SLANG_AST_CLASS(MakeOptionalExpr) + + // If `value` is null, this constructs an `Optional<T>` that doesn't have a value. + Expr* value = nullptr; + Expr* typeExpr = nullptr; +}; + /// A cast of a value to the same type, with different modifiers. /// /// The type being cast to is stored as this expression's `type`. diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index d439420a9..8461ff7a3 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -235,6 +235,24 @@ struct ASTIterator iterator->maybeDispatchCallback(expr); dispatchIfNotNull(expr->base.exp); } + void visitAsTypeExpr(AsTypeExpr* expr) + { + iterator->maybeDispatchCallback(expr); + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->typeExpr); + } + void visitIsTypeExpr(IsTypeExpr* expr) + { + iterator->maybeDispatchCallback(expr); + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->typeExpr.exp); + } + void visitMakeOptionalExpr(MakeOptionalExpr* expr) + { + iterator->maybeDispatchCallback(expr); + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->typeExpr); + } }; struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor> diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 43fe751ee..664c940a8 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -517,6 +517,11 @@ Type* PtrTypeBase::getValueType() return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); } +Type* OptionalType::getValueType() +{ + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void NamedExpressionType::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 5d4e42bfb..7b94cbe6d 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -572,6 +572,12 @@ class RefType : public ParamDirectionType SLANG_AST_CLASS(RefType) }; +class OptionalType : public BuiltinType +{ + SLANG_AST_CLASS(OptionalType) + Type* getValueType(); +}; + // A type alias of some kind (e.g., via `typedef`) class NamedExpressionType : public Type { diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index e77ea4981..5889c6140 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -390,6 +390,28 @@ namespace Slang return _isDeclaredSubtype(subType, subType, superTypeDeclRef, nullptr, nullptr); } + bool SemanticsVisitor::isDeclaredSubtype( + Type* subType, + Type* superType) + { + if (auto declRefType = as<DeclRefType>(superType)) + { + if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) + return _isDeclaredSubtype(subType, subType, aggTypeDeclRef, nullptr, nullptr); + } + return false; + } + + bool SemanticsVisitor::isInterfaceType(Type* type) + { + if (auto declRefType = as<DeclRefType>(type)) + { + if (auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>()) + return true; + } + return false; + } + Val* SemanticsVisitor::tryGetSubtypeWitness( Type* subType, DeclRef<AggTypeDecl> superTypeDeclRef) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index a787af211..714eba9a3 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1753,6 +1753,90 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitIsTypeExpr(IsTypeExpr* expr) + { + expr->typeExpr = CheckProperType(expr->typeExpr); + auto originalVal = CheckTerm(expr->value); + expr->type = m_astBuilder->getBoolType(); + expr->value = originalVal; + + // If value is a subtype of `type`, then this expr is always true. + if (isDeclaredSubtype(expr->value->type.type, expr->typeExpr.type)) + { + // Instead of returning a BoolLiteralExpr, we use a field to indicate this scenario, + // so that the language server can still see the original syntax tree. + expr->isAlwaysTrue = true; + return expr; + } + + // Otherwise, we need to ensure the target type is a subtype of value->type. + // For now we can only support the scenario where `expr->value` is an interface type. + if (!isInterfaceType(originalVal->type)) + { + getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); + } + + expr->value = maybeOpenExistential(originalVal); + expr->witnessArg = tryGetSubtypeWitness(expr->typeExpr.type, originalVal->type.type); + if (expr->witnessArg) + { + return expr; + } + + if (!as<ErrorType>(expr->typeExpr.type) && !as<ErrorType>(expr->value->type.type)) + { + getSink()->diagnose(expr, Diagnostics::typeNotInTheSameHierarchy, expr->value->type.type, expr->typeExpr.type); + } + + expr->type = m_astBuilder->getErrorType(); + return expr; + } + + Expr* SemanticsExprVisitor::visitAsTypeExpr(AsTypeExpr* expr) + { + TypeExp typeExpr; + typeExpr.exp = expr->typeExpr; + typeExpr = CheckProperType(typeExpr); + expr->value = CheckTerm(expr->value); + auto optType = m_astBuilder->getOptionalType(typeExpr.type); + expr->type = optType; + + // If value is a subtype of `type`, then this expr is equivalent to a CastToSuperTypeExpr. + if (auto witness = tryGetSubtypeWitness(expr->value->type.type, typeExpr.type)) + { + auto castToSuperType = createCastToSuperTypeExpr(typeExpr.type, expr->value, witness); + auto makeOptional = m_astBuilder->create<MakeOptionalExpr>(); + makeOptional->loc = expr->loc; + makeOptional->type = optType; + makeOptional->value = castToSuperType; + makeOptional->typeExpr = typeExpr.exp; + return makeOptional; + } + + // For now we can only support the scenario where `expr->value` is an interface type. + if (!isInterfaceType(expr->value->type)) + { + getSink()->diagnose(expr, Diagnostics::isOperatorValueMustBeInterfaceType); + } + + expr->typeExpr = typeExpr.exp; + expr->witnessArg = tryGetSubtypeWitness(typeExpr.type, expr->value->type.type); + if (expr->witnessArg) + { + expr->value = maybeOpenExistential(expr->value); + return expr; + } + + if (!as<ErrorType>(typeExpr.type) && !as<ErrorType>(expr->value->type.type)) + { + getSink()->diagnose(expr, Diagnostics::typeNotInTheSameHierarchy, expr->value->type.type, typeExpr.type); + } + + expr->type = m_astBuilder->getErrorType(); + + return expr; + } + Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr) { Expr* expr = inExpr; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f2f7a6bd1..df713d80f 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1278,6 +1278,13 @@ namespace Slang Type* subType, DeclRef<AggTypeDecl> superTypeDeclRef); + /// Check whether `subType` is a sub-type of `supType`. + bool isDeclaredSubtype( + Type* subType, + Type* supType); + + bool isInterfaceType(Type* type); + /// Check whether `subType` is a sub-type of `superTypeDeclRef`, /// and return a witness to the sub-type relationship if it holds /// (return null otherwise). @@ -1714,6 +1721,10 @@ namespace Slang Expr* visitTryExpr(TryExpr* expr); + Expr* visitIsTypeExpr(IsTypeExpr* expr); + + Expr* visitAsTypeExpr(AsTypeExpr* expr); + // // Some syntax nodes should not occur in the concrete input syntax, // and will only appear *after* checking is complete. We need to @@ -1740,6 +1751,7 @@ namespace Slang CASE(LetExpr) CASE(ExtractExistentialValueExpr) CASE(OpenRefExpr) + CASE(MakeOptionalExpr) #undef CASE diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 2b12c3de4..54d81da7d 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -251,6 +251,7 @@ DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for t DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.") DIAGNOSTIC(30017, Error, componentNotAccessibleFromShader, "component '$0' is not accessible from shader '$1'.") +DIAGNOSTIC(30018, Error, typeNotInTheSameHierarchy, "as/is operator requires '$0' and '$1' to be in the same type hierarchy.") DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'") DIAGNOSTIC(30020, Error, importOperatorReturnTypeMismatch, "import operator should return '$1', but the expression has type '$0''. do you forget 'project'?") DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)") @@ -286,6 +287,7 @@ DIAGNOSTIC(30200, Error, redeclaration, "declaration of '$0' conflicts with exis DIAGNOSTIC(30201, Error, functionRedefinition, "function '$0' already has a body") DIAGNOSTIC(30202, Error, functionRedeclarationWithDifferentReturnType, "function '$0' declared to return '$1' was previously declared to return '$2'") +DIAGNOSTIC(30300, Error, isOperatorValueMustBeInterfaceType, "'is'/'as' operator requires an interface-typed expression.") DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'") DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 21073f4c8..da8739a49 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -25,6 +25,7 @@ #include "slang-ir-lower-generics.h" #include "slang-ir-lower-tuple-types.h" #include "slang-ir-lower-result-type.h" +#include "slang-ir-lower-optional-type.h" #include "slang-ir-lower-bit-cast.h" #include "slang-ir-lower-reinterpret.h" #include "slang-ir-metadata.h" @@ -387,6 +388,9 @@ Result linkAndOptimizeIR( // will run a DCE pass to clean up after the specialization. // simplifyIR(irModule); + + lowerOptionalType(irModule, sink); + #if 0 dumpIRIfEnabled(codeGenContext, irModule, "AFTER DCE"); #endif diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 5c91766f7..39292e2b1 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -762,6 +762,13 @@ namespace Slang size += kRTTIHeaderSize; return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } + case kIROp_ExtractExistentialType: + { + auto existentialValue = type->getOperand(0); + auto interfaceType = cast<IRInterfaceType>(existentialValue->getDataType()); + auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); + return alignUp(offset, 4) + alignUp((SlangInt)size, 4); + } default: if (as<IRTextureTypeBase>(type) || as<IRSamplerStateTypeBase>(type)) { diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index aeb6d4ea1..978317ccd 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -57,6 +57,7 @@ INST(Nop, nop, 0, 0) INST(ConjunctionType, Conjunction, 0, 0) INST(AttributedType, Attributed, 0, 0) INST(ResultType, Result, 2, 0) + INST(OptionalType, Optional, 1, 0) INST(DifferentialPairType, DiffPair, 1, 0) @@ -292,7 +293,10 @@ INST(MakeResultError, makeResultError, 1, 0) INST(IsResultError, isResultError, 1, 0) INST(GetResultError, getResultError, 1, 0) INST(GetResultValue, getResultValue, 1, 0) - +INST(GetOptionalValue, getOptionalValue, 1, 0) +INST(OptionalHasValue, optionalHasValue, 1, 0) +INST(MakeOptionalValue, makeOptionalValue, 1, 0) +INST(MakeOptionalNone, makeOptionalNone, 1, 0) INST(Call, call, 1, 0) INST(RTTIObject, rtti_object, 0, 0) @@ -745,6 +749,7 @@ INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) INST(BitCast, bitCast, 1, 0) INST(Reinterpret, reinterpret, 1, 0) INST(CastPtrToBool, CastPtrToBool, 1, 0) +INST(IsType, IsType, 3, 0) INST(JVPDifferentiate, jvpDifferentiate, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 2e2dbed5a..7f4e991b0 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1902,6 +1902,33 @@ struct IRGetResultError : IRInst IRInst* getResultOperand() { return getOperand(0); } }; +struct IROptionalHasValue : IRInst +{ + IR_LEAF_ISA(OptionalHasValue) + + IRInst* getOptionalOperand() { return getOperand(0); } +}; + +struct IRGetOptionalValue : IRInst +{ + IR_LEAF_ISA(GetOptionalValue) + + IRInst* getOptionalOperand() { return getOperand(0); } +}; + +struct IRMakeOptionalValue : IRInst +{ + IR_LEAF_ISA(MakeOptionalValue) + + IRInst* getValue() { return getOperand(0); } +}; + +struct IRMakeOptionalNone : IRInst +{ + IR_LEAF_ISA(MakeOptionalNone) + IRInst* getDefaultValue() { return getOperand(0); } +}; + /// An instruction that packs a concrete value into an existential-type "box" struct IRMakeExistential : IRInst { @@ -1985,6 +2012,17 @@ struct IRLiveRangeStart : IRLiveRangeMarker IR_LEAF_ISA(LiveRangeStart); }; +struct IRIsType : IRInst +{ + IR_LEAF_ISA(IsType); + + IRInst* getValue() { return getOperand(0); } + IRInst* getValueWitness() { return getOperand(1); } + + IRInst* getTypeOperand() { return getOperand(2); } + IRInst* getTargetWitness() { return getOperand(3); } +}; + /// Demarks where the referenced item is no longer live, optimimally (although not /// necessarily) at the previous instruction. /// @@ -2256,6 +2294,7 @@ public: IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2, IRType* type3); IRResultType* getResultType(IRType* valueType, IRType* errorType); + IROptionalType* getOptionalType(IRType* valueType); IRBasicBlockType* getBasicBlockType(); IRWitnessTableType* getWitnessTableType(IRType* baseType); @@ -2504,7 +2543,10 @@ public: IRInst* emitIsResultError(IRInst* result); IRInst* emitGetResultError(IRInst* result); IRInst* emitGetResultValue(IRInst* result); - + IRInst* emitOptionalHasValue(IRInst* optValue); + IRInst* emitGetOptionalValue(IRInst* optValue); + IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value); + IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); IRInst* emitMakeVector( IRType* type, UInt argCount, @@ -2581,6 +2623,8 @@ public: IRUndefined* emitUndefined(IRType* type); + IRInst* emitReinterpret(IRInst* type, IRInst* value); + IRInst* findOrAddInst( IRType* type, IROp op, @@ -2717,6 +2761,8 @@ public: IRInst* coord, IRInst* value); + IRInst* emitIsType(IRInst* value, IRInst* witness, IRInst* typeOperand, IRInst* targetWitness); + IRInst* emitFieldExtract( IRType* type, IRInst* base, diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 3b9a17738..c95f6976c 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -9,11 +9,12 @@ #include "slang-ir-lower-generic-function.h" #include "slang-ir-lower-generic-call.h" #include "slang-ir-lower-generic-type.h" +#include "slang-ir-inst-pass-base.h" #include "slang-ir-specialize-dispatch.h" #include "slang-ir-specialize-dynamic-associatedtype-lookup.h" #include "slang-ir-witness-table-wrapper.h" -#include "slang-ir-ssa.h" -#include "slang-ir-dce.h" +#include "slang-ir-ssa-simplification.h" + namespace Slang { @@ -93,6 +94,23 @@ namespace Slang inst->removeAndDeallocate(); } } + + void lowerIsTypeInsts(SharedGenericsLoweringContext* sharedContext) + { + InstPassBase pass(sharedContext->module); + pass.processInstsOfType<IRIsType>(kIROp_IsType, [&](IRIsType* inst) + { + auto witnessTableType = as<IRWitnessTableTypeBase>(inst->getValueWitness()->getDataType()); + if (witnessTableType && isComInterfaceType((IRType*)witnessTableType->getConformanceType())) + return; + IRBuilder builder(sharedContext->sharedBuilderStorage); + builder.setInsertBefore(inst); + auto eqlInst = builder.emitEql(builder.emitGetSequentialIDInst(inst->getValueWitness()), + builder.emitGetSequentialIDInst(inst->getTargetWitness())); + inst->replaceUsesWith(eqlInst); + inst->removeAndDeallocate(); + }); + } // Turn all references of witness table or RTTI objects into integer IDs, generate // specialized `switch` based dispatch functions based on witness table IDs, and remove @@ -105,6 +123,8 @@ namespace Slang if (sink->getErrorCount() != 0) return; + lowerIsTypeInsts(sharedContext); + specializeDynamicAssociatedTypeLookup(sharedContext); if (sink->getErrorCount() != 0) return; @@ -112,6 +132,7 @@ namespace Slang sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); sharedContext->mapInterfaceRequirementKeyValue.Clear(); + specializeRTTIObjectReferences(sharedContext); cleanUpRTTIHandleTypes(sharedContext); @@ -175,6 +196,7 @@ namespace Slang // and used to create a tuple representing the existential value. augmentMakeExistentialInsts(module); + lowerGenericFunctions(&sharedContext); if (sink->getErrorCount() != 0) return; @@ -200,6 +222,8 @@ namespace Slang // real RTTI objects and witness tables. specializeRTTIObjects(&sharedContext, sink); + simplifyIR(module); + lowerTuples(module, sink); if (sink->getErrorCount() != 0) return; @@ -210,7 +234,6 @@ namespace Slang // We might have generated new temporary variables during lowering. // An SSA pass can clean up unnecessary load/stores. - constructSSA(module); - eliminateDeadCode(module); + simplifyIR(module); } } // namespace Slang diff --git a/source/slang/slang-ir-lower-optional-type.cpp b/source/slang/slang-ir-lower-optional-type.cpp new file mode 100644 index 000000000..79be4e042 --- /dev/null +++ b/source/slang/slang-ir-lower-optional-type.cpp @@ -0,0 +1,239 @@ +// slang-ir-lower-optional-type.cpp + +#include "slang-ir-lower-optional-type.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + struct OptionalTypeLoweringContext + { + IRModule* module; + DiagnosticSink* sink; + + SharedIRBuilder sharedBuilderStorage; + + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + + struct LoweredOptionalTypeInfo : public RefObject + { + IRType* optionalType = nullptr; + IRType* valueType = nullptr; + IRType* loweredType = nullptr; + IRStructField* valueField = nullptr; + IRStructField* hasValueField = nullptr; + }; + Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> mapLoweredTypeToOptionalTypeInfo; + Dictionary<IRInst*, RefPtr<LoweredOptionalTypeInfo>> loweredOptionalTypes; + + IRType* maybeLowerOptionalType(IRBuilder* builder, IRType* type) + { + if (auto info = getLoweredOptionalType(builder, type)) + return info->loweredType; + else + return type; + } + + LoweredOptionalTypeInfo* getLoweredOptionalType(IRBuilder* builder, IRInst* type) + { + if (auto loweredInfo = loweredOptionalTypes.TryGetValue(type)) + return loweredInfo->Ptr(); + if (auto loweredInfo = mapLoweredTypeToOptionalTypeInfo.TryGetValue(type)) + return loweredInfo->Ptr(); + + if (!type) + return nullptr; + if (type->getOp() != kIROp_OptionalType) + return nullptr; + + RefPtr<LoweredOptionalTypeInfo> info = new LoweredOptionalTypeInfo(); + info->optionalType = (IRType*)type; + auto optionalType = cast<IROptionalType>(type); + auto valueType = optionalType->getValueType(); + info->valueType = valueType; + + auto structType = builder->createStructType(); + info->loweredType = structType; + builder->addNameHintDecoration(structType, UnownedStringSlice("OptionalType")); + + info->valueType = valueType; + auto valueKey = builder->createStructKey(); + builder->addNameHintDecoration(valueKey, UnownedStringSlice("value")); + info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType); + + auto boolType = builder->getBoolType(); + auto hasValueKey = builder->createStructKey(); + builder->addNameHintDecoration(hasValueKey, UnownedStringSlice("hasValue")); + info->hasValueField = builder->createStructField(structType, hasValueKey, (IRType*)boolType); + + mapLoweredTypeToOptionalTypeInfo[info->loweredType] = info; + loweredOptionalTypes[type] = info; + return info.Ptr(); + } + + void addToWorkList( + IRInst* inst) + { + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as<IRGeneric>(ii)) + return; + } + + if (workListSet.Contains(inst)) + return; + + workList.add(inst); + workListSet.Add(inst); + } + + void processMakeOptionalValue(IRMakeOptionalValue* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto info = getLoweredOptionalType(builder, inst->getDataType()); + List<IRInst*> operands; + operands.add(inst->getOperand(0)); + operands.add(builder->getBoolValue(true)); + auto makeStruct = builder->emitMakeStruct(info->loweredType, operands); + inst->replaceUsesWith(makeStruct); + inst->removeAndDeallocate(); + } + + void processMakeOptionalNone(IRMakeOptionalNone* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto info = getLoweredOptionalType(builder, inst->getDataType()); + + List<IRInst*> operands; + operands.add(inst->getDefaultValue()); + operands.add(builder->getBoolValue(false)); + auto makeStruct = builder->emitMakeStruct(info->loweredType, operands); + inst->replaceUsesWith(makeStruct); + inst->removeAndDeallocate(); + } + + IRInst* getOptionalHasValue(IRBuilder* builder, IRInst* optionalInst) + { + auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, optionalInst->getDataType()); + SLANG_ASSERT(loweredOptionalTypeInfo); + + auto value = builder->emitFieldExtract( + builder->getBoolType(), + optionalInst, + loweredOptionalTypeInfo->hasValueField->getKey()); + return value; + } + + void processGetOptionalHasValue(IROptionalHasValue* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto optionalValue = inst->getOptionalOperand(); + auto hasVal = getOptionalHasValue(builder, optionalValue); + inst->replaceUsesWith(hasVal); + inst->removeAndDeallocate(); + } + + void processGetOptionalValue(IRGetOptionalValue* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto base = inst->getOptionalOperand(); + auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, base->getDataType()); + SLANG_ASSERT(loweredOptionalTypeInfo); + SLANG_ASSERT(loweredOptionalTypeInfo->valueField); + auto getElement = builder->emitFieldExtract( + loweredOptionalTypeInfo->valueType, + base, + loweredOptionalTypeInfo->valueField->getKey()); + inst->replaceUsesWith(getElement); + inst->removeAndDeallocate(); + } + + void processOptionalType(IROptionalType* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto loweredOptionalTypeInfo = getLoweredOptionalType(builder, inst); + SLANG_ASSERT(loweredOptionalTypeInfo); + SLANG_UNUSED(loweredOptionalTypeInfo); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_MakeOptionalValue: + processMakeOptionalValue((IRMakeOptionalValue*)inst); + break; + case kIROp_MakeOptionalNone: + processMakeOptionalNone((IRMakeOptionalNone*)inst); + break; + case kIROp_OptionalHasValue: + processGetOptionalHasValue((IROptionalHasValue*)inst); + break; + case kIROp_GetOptionalValue: + processGetOptionalValue((IRGetOptionalValue*)inst); + break; + case kIROp_OptionalType: + processOptionalType((IROptionalType*)inst); + break; + default: + break; + } + } + + void processModule() + { + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->init(module); + + // Deduplicate equivalent types. + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + + // Replace all optional types with lowered struct types. + for (auto kv : loweredOptionalTypes) + { + kv.Key->replaceUsesWith(kv.Value->loweredType); + } + } + }; + + void lowerOptionalType(IRModule* module, DiagnosticSink* sink) + { + OptionalTypeLoweringContext context; + context.module = module; + context.sink = sink; + context.processModule(); + } +} diff --git a/source/slang/slang-ir-lower-optional-type.h b/source/slang/slang-ir-lower-optional-type.h new file mode 100644 index 000000000..1a011da26 --- /dev/null +++ b/source/slang/slang-ir-lower-optional-type.h @@ -0,0 +1,16 @@ +// slang-ir-lower-optional-type.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + struct IRModule; + class DiagnosticSink; + + /// Lower `IROptionalType<T,E>` types to ordinary `struct`s. + void lowerOptionalType( + IRModule* module, + DiagnosticSink* sink); + +} diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index ffdb84c4a..d67c87ca9 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -96,6 +96,87 @@ struct PeepholeContext : InstPassBase changed = true; } break; + case kIROp_IsType: + { + auto isTypeInst = as<IRIsType>(inst); + auto actualType = isTypeInst->getValue()->getDataType(); + if (isTypeEqual(actualType, (IRType*)isTypeInst->getTypeOperand())) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + auto trueVal = builder.getBoolValue(true); + inst->replaceUsesWith(trueVal); + inst->removeAndDeallocate(); + changed = true; + } + } + break; + case kIROp_Reinterpret: + { + if (isTypeEqual(inst->getOperand(0)->getDataType(), inst->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + changed = true; + } + } + break; + case kIROp_UnpackAnyValue: + { + if (inst->getOperand(0)->getOp() == kIROp_PackAnyValue) + { + if (isTypeEqual(inst->getOperand(0)->getOperand(0)->getDataType(), inst->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)->getOperand(0)); + inst->removeAndDeallocate(); + changed = true; + } + } + } + break; + case kIROp_PackAnyValue: + { + // Pack(obj: anyValueN) : anyValueN --> obj + if (isTypeEqual(inst->getOperand(0)->getDataType(), inst->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + changed = true; + } + } + break; + case kIROp_GetOptionalValue: + { + if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue) + { + inst->replaceUsesWith(inst->getOperand(0)->getOperand(0)); + inst->removeAndDeallocate(); + changed = true; + } + } + break; + case kIROp_OptionalHasValue: + { + if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalValue) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + auto trueVal = builder.getBoolValue(true); + inst->replaceUsesWith(trueVal); + inst->removeAndDeallocate(); + changed = true; + } + else if (inst->getOperand(0)->getOp() == kIROp_MakeOptionalNone) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + auto falseVal = builder.getBoolValue(false); + inst->replaceUsesWith(falseVal); + inst->removeAndDeallocate(); + changed = true; + } + } + break; default: break; } @@ -105,6 +186,7 @@ struct PeepholeContext : InstPassBase { SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); changed = false; processAllInsts([this](IRInst* inst) { processInst(inst); }); diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index 450867abb..0ca0933f1 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -253,8 +253,6 @@ void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* } if (as<IRWitnessTable>(args[0]->getDataType())) continue; - auto seqIdArg = builder.emitGetSequentialIDInst(args[0]); - args[0] = seqIdArg; auto newCall = builder.emitCallInst(call->getFullType(), newDispatchFunc, args); call->replaceUsesWith(newCall); call->removeAndDeallocate(); diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index 52c8edca6..dbcc7a54c 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -168,25 +168,15 @@ struct AssociatedTypeLookupSpecializationContext void processGetSequentialIDInst(IRGetSequentialID* inst) { - if (inst->getRTTIOperand()->getDataType()->getOp() == kIROp_WitnessTableIDType) - { - // If the operand is a witness table id, just return the operand. - inst->replaceUsesWith(inst->getRTTIOperand()); - inst->removeAndDeallocate(); - } - else if (inst->getRTTIOperand()->getDataType()->getOp() == kIROp_VectorType) - { - // If the operand is a witness table, it is already replaced with a uint2 - // at this point, where the first element in the uint2 is the id of the - // witness table. - auto vectorType = inst->getRTTIOperand()->getDataType(); - IRBuilder builder(sharedContext->sharedBuilderStorage); - builder.setInsertBefore(inst); - UInt index = 0; - auto id = builder.emitSwizzle(as<IRVectorType>(vectorType)->getElementType(), inst->getRTTIOperand(), 1, &index); - inst->replaceUsesWith(id); - inst->removeAndDeallocate(); - } + // If the operand is a witness table, it is already replaced with a uint2 + // at this point, where the first element in the uint2 is the id of the + // witness table. + IRBuilder builder(sharedContext->sharedBuilderStorage); + builder.setInsertBefore(inst); + UInt index = 0; + auto id = builder.emitSwizzle(builder.getUIntType(), inst->getRTTIOperand(), 1, &index); + inst->replaceUsesWith(id); + inst->removeAndDeallocate(); } void processModule() diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index c66f0d555..fd7cbe408 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2587,7 +2587,7 @@ namespace Slang IRAnyValueType* IRBuilder::getAnyValueType(IRIntegerValue size) { return (IRAnyValueType*)getType(kIROp_AnyValueType, - getIntValue(getIntType(), size)); + getIntValue(getUIntType(), size)); } IRAnyValueType* IRBuilder::getAnyValueType(IRInst* size) @@ -2624,6 +2624,11 @@ namespace Slang return (IRResultType*)getType(kIROp_ResultType, 2, operands); } + IROptionalType* IRBuilder::getOptionalType(IRType* valueType) + { + return (IROptionalType*)getType(kIROp_OptionalType, valueType); + } + IRBasicBlockType* IRBuilder::getBasicBlockType() { return (IRBasicBlockType*)getType(kIROp_BasicBlockType); @@ -2971,6 +2976,11 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitReinterpret(IRInst* type, IRInst* value) + { + return emitIntrinsicInst((IRType*)type, kIROp_Reinterpret, 1, &value); + } + IRLiveRangeStart* IRBuilder::emitLiveRangeStart(IRInst* referenced) { // This instruction doesn't produce any result, @@ -3323,6 +3333,42 @@ namespace Slang &result); } + IRInst* IRBuilder::emitOptionalHasValue(IRInst* optValue) + { + return emitIntrinsicInst( + getBoolType(), + kIROp_OptionalHasValue, + 1, + &optValue); + } + + IRInst* IRBuilder::emitGetOptionalValue(IRInst* optValue) + { + return emitIntrinsicInst( + cast<IROptionalType>(optValue->getDataType())->getValueType(), + kIROp_GetOptionalValue, + 1, + &optValue); + } + + IRInst* IRBuilder::emitMakeOptionalValue(IRInst* optType, IRInst* value) + { + return emitIntrinsicInst( + (IRType*)optType, + kIROp_MakeOptionalValue, + 1, + &value); + } + + IRInst* IRBuilder::emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue) + { + return emitIntrinsicInst( + (IRType*)optType, + kIROp_MakeOptionalNone, + 1, + &defaultValue); + } + IRInst* IRBuilder::emitMakeVector( IRType* type, UInt argCount, @@ -3819,6 +3865,14 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitIsType(IRInst* value, IRInst* witness, IRInst* typeOperand, IRInst* targetWitness) + { + IRInst* args[] = { value, witness, typeOperand, targetWitness }; + auto inst = createInst<IRIsType>(this, kIROp_IsType, getBoolType(), SLANG_COUNT_OF(args), args); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitFieldExtract( IRType* type, IRInst* base, @@ -6079,6 +6133,10 @@ namespace Slang case kIROp_GetResultError: case kIROp_GetResultValue: case kIROp_IsResultError: + case kIROp_MakeOptionalValue: + case kIROp_MakeOptionalNone: + case kIROp_OptionalHasValue: + case kIROp_GetOptionalValue: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads case kIROp_ImageSubscript: case kIROp_FieldExtract: @@ -6118,6 +6176,9 @@ namespace Slang case kIROp_WrapExistential: case kIROp_BitCast: case kIROp_AllocObj: + case kIROp_PackAnyValue: + case kIROp_UnpackAnyValue: + case kIROp_Reinterpret: return false; } } diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 47a724def..47c6621f3 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1488,6 +1488,14 @@ struct IRResultType : IRType IRType* getErrorType() { return (IRType*)getOperand(1); } }; +/// Represents an `Optional<T>`. +struct IROptionalType : IRType +{ + IR_LEAF_ISA(OptionalType) + + IRType* getValueType() { return (IRType*)getOperand(0); } +}; + struct IRTypeType : IRType { IR_LEAF_ISA(TypeType); diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 3de335c33..353c98fa4 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -383,6 +383,24 @@ public: } return dispatchIfNotNull(expr->base.exp); } + bool visitAsTypeExpr(AsTypeExpr* expr) + { + if (dispatchIfNotNull(expr->value)) + return true; + return dispatchIfNotNull(expr->typeExpr); + } + bool visitIsTypeExpr(IsTypeExpr* expr) + { + if (dispatchIfNotNull(expr->value)) + return true; + return dispatchIfNotNull(expr->typeExpr.exp); + } + bool visitMakeOptionalExpr(MakeOptionalExpr* expr) + { + if (dispatchIfNotNull(expr->typeExpr)) + return true; + return dispatchIfNotNull(expr->value); + } bool visitModifiedTypeExpr(ModifiedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); } bool visitTryExpr(TryExpr* expr) { return dispatchIfNotNull(expr->base); } diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index 21586089a..4ae5bcb37 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -30,7 +30,8 @@ static const char* kStmtKeywords[] = { "protected", "typedef", "typealias", "uniform", "export", "groupshared", "extension", "associatedtype", "this", "namespace", "This", "using", "__generic", "__exported", "import", "enum", "break", "continue", - "discard", "defer", "cbuffer", "tbuffer", "func"}; + "discard", "defer", "cbuffer", "tbuffer", "func", "is", + "as", "nullptr", "true", "false"}; static const char* hlslSemanticNames[] = { "register", diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index be7373d40..aef60c9d9 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3384,6 +3384,24 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple(irLit); } + LoweredValInfo visitMakeOptionalExpr(MakeOptionalExpr* expr) + { + if (expr->value) + { + auto val = lowerRValueExpr(context, expr->value); + auto optType = lowerType(context, expr->type); + auto irVal = context->irBuilder->emitMakeOptionalValue(optType, val.val); + return LoweredValInfo::simple(irVal); + } + else + { + auto optType = lowerType(context, expr->type); + auto defaultVal = getDefaultVal(as<OptionalType>(expr->type)->getValueType()); + auto irVal = context->irBuilder->emitMakeOptionalNone(optType, defaultVal.val); + return LoweredValInfo::simple(irVal); + } + } + LoweredValInfo visitAggTypeCtorExpr(AggTypeCtorExpr* /*expr*/) { SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression"); @@ -3911,6 +3929,50 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitAsTypeExpr(AsTypeExpr* expr) + { + auto value = lowerLValueExpr(context, expr->value); + auto existentialInfo = value.getExtractedExistentialValInfo(); + auto optType = lowerType(context, expr->type); + SLANG_RELEASE_ASSERT(optType->getOp() == kIROp_OptionalType); + auto targetType = optType->getOperand(0); + auto witness = lowerSimpleVal(context, expr->witnessArg); + auto builder = getBuilder(); + auto var = builder->emitVar(optType); + auto isType = builder->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, targetType, witness); + IRBlock* trueBlock; + IRBlock* falseBlock; + IRBlock* afterBlock; + builder->emitIfElseWithBlocks(isType, trueBlock, falseBlock, afterBlock); + builder->setInsertInto(trueBlock); + auto irVal = builder->emitReinterpret(targetType, existentialInfo->extractedVal); + auto optionalVal = builder->emitMakeOptionalValue(optType, irVal); + builder->emitStore(var, optionalVal); + builder->emitBranch(afterBlock); + builder->setInsertInto(falseBlock); + auto defaultVal = getDefaultVal(as<OptionalType>(expr->type)->getValueType()); + auto noneVal = builder->emitMakeOptionalNone(optType, defaultVal.val); + builder->emitStore(var, noneVal); + builder->emitBranch(afterBlock); + builder->setInsertInto(afterBlock); + auto result = builder->emitLoad(var); + return LoweredValInfo::simple(result); + } + + LoweredValInfo visitIsTypeExpr(IsTypeExpr* expr) + { + if (expr->isAlwaysTrue) + { + return LoweredValInfo::simple(getBuilder()->getBoolValue(true)); + } + auto value = lowerLValueExpr(context, expr->value); + auto type = lowerType(context, expr->type); + auto witness = lowerSimpleVal(context, expr->witnessArg); + auto existentialInfo = value.getExtractedExistentialValInfo(); + auto irVal = getBuilder()->emitIsType(existentialInfo->extractedVal, existentialInfo->witnessTable, type, witness); + return LoweredValInfo::simple(irVal); + } + LoweredValInfo visitModifierCastExpr( ModifierCastExpr* expr) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 17794a4b5..4baa7211e 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -4735,9 +4735,9 @@ namespace Slang - Precedence GetOpLevel(Parser* parser, TokenType type) + Precedence GetOpLevel(Parser* parser, const Token& token) { - switch(type) + switch(token.type) { case TokenType::QuestionMark: return Precedence::TernaryConditional; @@ -4790,6 +4790,10 @@ namespace Slang case TokenType::OpMod: return Precedence::Multiplicative; default: + if (token.getContent() == "is" || token.getContent() == "as") + { + return Precedence::RelationalComparison; + } return Precedence::Invalid; } } @@ -4840,16 +4844,39 @@ namespace Slang auto expr = inExpr; for(;;) { - auto opTokenType = parser->tokenReader.peekTokenType(); - auto opPrec = GetOpLevel(parser, opTokenType); + auto opToken = parser->tokenReader.peekToken(); + auto opPrec = GetOpLevel(parser, opToken); if(opPrec < prec) break; + // Special case the "is" and "as" operators. + if (opToken.type == TokenType::Identifier) + { + if (opToken.getContent() == "is") + { + auto isExpr = parser->astBuilder->create<IsTypeExpr>(); + isExpr->value = expr; + parser->ReadToken(); + isExpr->typeExpr = parser->ParseTypeExp(); + expr = isExpr; + continue; + } + else if (opToken.getContent() == "as") + { + auto asExpr = parser->astBuilder->create<AsTypeExpr>(); + asExpr->value = expr; + parser->ReadToken(); + asExpr->typeExpr = parser->ParseType(); + expr = asExpr; + continue; + } + } + auto op = parseOperator(parser); // Special case the `?:` operator since it is the // one non-binary case we need to deal with. - if(opTokenType == TokenType::QuestionMark) + if(opToken.type == TokenType::QuestionMark) { SelectExpr* select = parser->astBuilder->create<SelectExpr>(); select->loc = op->loc; @@ -4869,7 +4896,7 @@ namespace Slang for(;;) { - auto nextOpPrec = GetOpLevel(parser, parser->tokenReader.peekTokenType()); + auto nextOpPrec = GetOpLevel(parser, parser->tokenReader.peekToken()); if((GetAssociativityFromLevel(nextOpPrec) == Associativity::Right) ? (nextOpPrec < opPrec) : (nextOpPrec <= opPrec)) break; @@ -4877,7 +4904,7 @@ namespace Slang right = parseInfixExprWithPrecedence(parser, right, nextOpPrec); } - if (opTokenType == TokenType::OpAssign) + if (opToken.type == TokenType::OpAssign) { AssignExpr* assignExpr = parser->astBuilder->create<AssignExpr>(); assignExpr->loc = op->loc; |
