diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-04 15:47:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-04 15:47:39 -0700 |
| commit | a2d90fb275962da84611160f8ddd74d934a68dbd (patch) | |
| tree | 066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-check-expr.cpp | |
| parent | 17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff) | |
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val.
* Update project files
* Fix warning.
* Fix.
* Fix.
* Remove `Val::_equalsImplOverride`.
* Rmove `Val::_getHashCodeOverride`.
* Remove `semanticVisitor` param from `resolve`.
* Cleanups.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 185 |
1 files changed, 77 insertions, 108 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3c90c3ed8..e343e3113 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -22,7 +22,7 @@ namespace Slang DeclRefType* SemanticsVisitor::getExprDeclRefType(Expr * expr) { if (auto typetype = as<TypeType>(expr->type)) - return dynamicCast<DeclRefType>(typetype->type); + return dynamicCast<DeclRefType>(typetype->getType()); else return as<DeclRefType>(expr->type); } @@ -154,10 +154,8 @@ namespace Slang // return maybeMoveTemp(expr, [&](DeclRef<VarDeclBase> varDeclRef) { - ExtractExistentialType* openedType = m_astBuilder->create<ExtractExistentialType>(); - openedType->declRef = varDeclRef; - openedType->originalInterfaceType = expr->type.type; - openedType->originalInterfaceDeclRef = interfaceDeclRef; + ExtractExistentialType* openedType = m_astBuilder->getOrCreate<ExtractExistentialType>( + varDeclRef, expr->type.type, interfaceDeclRef); ExtractExistentialValueExpr* openedValue = m_astBuilder->create<ExtractExistentialValueExpr>(); openedValue->declRef = varDeclRef; @@ -202,29 +200,9 @@ namespace Slang if(auto declRefType = as<DeclRefType>(exprType)) { - if(auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>()) + if(auto interfaceDeclRef = declRefType->getDeclRef().as<InterfaceDecl>()) { - // Is there an this-type substitution being applied, so that - // we are referencing the interface type through a concrete - // type (e.g., a type parameter constrained to this interface)? - // - // Because of the way that substitutions need to mirror the nesting - // hierarchy of declarations, any this-type substitution pertaining - // to the chosen interface decl must be the first substitution on - // the list (which is a linked list from the "inside" out). - // - auto thisTypeSubst = as<ThisTypeSubstitution>(interfaceDeclRef.getSubst()); - if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.getDecl()) - { - // This isn't really an existential type, because somebody - // has already filled in a this-type substitution. - } - else - { - // Okay, here is the case that matters. - // - return openExistential(expr, interfaceDeclRef); - } + return openExistential(expr, interfaceDeclRef); } } @@ -317,7 +295,7 @@ namespace Slang // actually names a type, because in that case we are doing // a static member reference. // - if (auto typeType = as<TypeType>(baseExpr->type)) + if (auto typeType = as<TypeType>(baseExpr->type->getCanonicalType())) { // Before forming the reference, we will check if the // member being referenced can even be used as a static @@ -340,7 +318,7 @@ namespace Slang getSink()->diagnose( loc, Diagnostics::staticRefToNonStaticMember, - typeType->type, + typeType->getType(), declRef.getName()); } @@ -493,9 +471,9 @@ namespace Slang case LookupResultItem::Breadcrumb::Kind::SuperType: { auto witness = as<SubtypeWitness>(breadcrumb->val); - if (auto subDeclRefType = as<DeclRefType>(witness->sub)) + if (auto subDeclRefType = as<DeclRefType>(witness->getSub())) { - if (!as<InterfaceDecl>(subDeclRefType->declRef.getDecl())) + if (!as<InterfaceDecl>(subDeclRefType->getDeclRef().getDecl())) { // Store the inner most concrete super type. subType = subDeclRefType; @@ -515,10 +493,13 @@ namespace Slang return nullptr; // Don't synthesize for generic parameters. - auto parent = as<AggTypeDecl>(subType->declRef.getDecl()); + auto parent = as<AggTypeDecl>(subType->getDeclRef().getDecl()); if (!parent) return nullptr; + // Don't synthesize for ThisType. + if (as<ThisTypeDecl>(subType->getDeclRef().getDecl())) + return nullptr; // If we reach here, we are expecting a synthesized decl defined in `subType`. // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl @@ -607,7 +588,7 @@ namespace Slang // auto witness = as<SubtypeWitness>(breadcrumb->val); SLANG_ASSERT(witness); - auto expr = createCastToSuperTypeExpr(witness->sup, bb, witness); + auto expr = createCastToSuperTypeExpr(witness->getSup(), bb, witness); // Note that we allow a cast of an l-value to // be used as an l-value here because it enables @@ -926,7 +907,7 @@ namespace Slang if (auto declRefType = as<DeclRefType>(type)) { - if (auto builtinRequirement = declRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>()) + if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>()) { if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType) { @@ -935,6 +916,7 @@ namespace Slang return type; } } + type = resolveType(type); if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()))) { auto diffTypeLookupResult = lookUpMember( @@ -964,10 +946,10 @@ namespace Slang auto diffTypeExpr = ConstructLookupResultExpr( diffTypeLookupResult.item, baseTypeExpr, - declRefType->declRef.getLoc(), + declRefType->getDeclRef().getLoc(), baseTypeExpr); - return ExtractTypeFromTypeRepr(diffTypeExpr); + return resolveType(ExtractTypeFromTypeRepr(diffTypeExpr)); } } } @@ -991,7 +973,7 @@ namespace Slang SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr); if (witness) { - m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.addIfNotExists(type->declRef, witness); + m_parentDifferentiableAttr->addType(type->getDeclRef(), witness); } } @@ -1048,7 +1030,7 @@ namespace Slang { addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); } - if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) + if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>()) { foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) { @@ -1061,23 +1043,13 @@ namespace Slang maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType); }); } - for (auto subst = declRefType->declRef.getSubst(); subst; subst = subst->getOuter()) - { - if (auto genSubst = as<GenericSubstitution>(subst)) + SubstitutionSet(declRefType->getDeclRef()).forEachSubstitutionArg([&](Val* arg) { - for (auto arg : genSubst->getArgs()) + if (auto typeArg = as<Type>(arg)) { - if (auto typeArg = as<Type>(arg)) - { - maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg); - } + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg); } - } - else if (auto thisSubst = as<ThisTypeSubstitution>(subst)) - { - maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, thisSubst->witness->sub); - } - } + }); return; } } @@ -1302,7 +1274,7 @@ namespace Slang if (auto constArgVal = as<ConstantIntVal>(argVal)) { - constArgVals[a] = constArgVal->value; + constArgVals[a] = constArgVal->getValue(); } else { @@ -1366,12 +1338,13 @@ namespace Slang || opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") || opName == getName("?:") || opName == getName("<<") || opName == getName(">>")) { - auto result = m_astBuilder->create<FuncCallIntVal>(invokeExpr.getExpr()->type.type); - result->args.addRange(argVals, argCount); - result->funcDeclRef = funcDeclRef; - result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute( - m_astBuilder, funcDeclRefExpr.getSubsts())); - SLANG_RELEASE_ASSERT(result->funcType); + auto result = m_astBuilder->getOrCreate<FuncCallIntVal>( + invokeExpr.getExpr()->type.type, + funcDeclRef, + as<Type>(funcDeclRefExpr.getExpr()->type->substitute( + m_astBuilder, funcDeclRefExpr.getSubsts())), + makeArrayView(argVals, argCount)); + SLANG_RELEASE_ASSERT(result->getFuncType()); return result; } return nullptr; @@ -1507,18 +1480,14 @@ namespace Slang if (isInterfaceRequirement(decl)) { - for (auto subst = declRef.getSubst(); subst; subst = subst->getOuter()) - { - if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst)) - { - auto val = WitnessLookupIntVal::tryFold( - m_astBuilder, - thisTypeSubst->witness, - decl, - declRef.substitute(m_astBuilder, decl->type.type)); - return as<IntVal>(val); - } - } + auto witness = findThisTypeWitness(SubstitutionSet(declRef), as<InterfaceDecl>(decl->parentDecl)); + + auto val = WitnessLookupIntVal::tryFold( + m_astBuilder, + witness, + decl, + declRef.substitute(m_astBuilder, decl->type.type)); + return as<IntVal>(val); } if (!getInitExpr(m_astBuilder, declRef)) @@ -1785,7 +1754,7 @@ namespace Slang getSink()->diagnose(subscriptExpr, Diagnostics::multiDimensionalArrayNotSupported); } - auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); + auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->getType())); auto arrayType = getArrayType( m_astBuilder, elementType, @@ -1804,7 +1773,7 @@ namespace Slang { return CheckSimpleSubscriptExpr( subscriptExpr, - vecType->elementType); + vecType->getElementType()); } else if (auto matType = as<MatrixExpressionType>(baseType)) { @@ -1975,8 +1944,8 @@ namespace Slang if (basicTypeA && basicTypeB) { - const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->baseType); - const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->baseType); + const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->getBaseType()); + const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->getBaseType()); // TODO(JS): Initially this tries to limit where LValueImplict casts happen. // We could in principal allow different sizes, as long as we converted to a temprorary @@ -2021,7 +1990,7 @@ namespace Slang // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues if(auto funcType = as<FuncType>(invoke->functionExpr->type)) { - if (!funcType->errorType->equals(m_astBuilder->getBottomType())) + if (!funcType->getErrorType()->equals(m_astBuilder->getBottomType())) { // If the callee throws, make sure we are inside a try clause. if (m_enclosingTryClauseType == TryClauseType::None) @@ -2230,7 +2199,7 @@ namespace Slang return result; } - Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr *expr) + Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr) { // check the base expression first expr->functionExpr = CheckTerm(expr->functionExpr); @@ -2312,6 +2281,7 @@ namespace Slang auto lookupResult = lookUp( m_astBuilder, this, expr->name, expr->scope); + if (expr->name == getSession()->getCompletionRequestTokenName()) { auto scopeKind = CompletionSuggestions::ScopeKind::Expr; @@ -2357,7 +2327,7 @@ namespace Slang if (auto modifiedType = as<ModifiedType>(primalType)) { if (modifiedType->findModifier<NoDiffModifierVal>()) - return modifiedType->base; + return modifiedType->getBase(); } // Get a reference to the builtin 'IDifferentiable' interface @@ -2379,23 +2349,23 @@ namespace Slang // Resolve JVP type here. // Note that this type checking needs to be in sync with // the auto-generation logic in slang-ir-jvp-diff.cpp - - FuncType* jvpType = m_astBuilder->create<FuncType>(); + List<Type*> paramTypes; // The JVP return type is float if primal return type is float // void otherwise. // - jvpType->resultType = getDifferentialPairType(originalType->getResultType()); + auto resultType = getDifferentialPairType(originalType->getResultType()); // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); - jvpType->errorType = originalType->errorType; + SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType())); + auto errorType = originalType->getErrorType(); for (Index i = 0; i < originalType->getParamCount(); i++) { if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) - jvpType->paramTypes.add(jvpParamType); + paramTypes.add(jvpParamType); } + FuncType* jvpType = m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType); return jvpType; } @@ -2405,16 +2375,15 @@ namespace Slang // Resolve backward diff type here. // Note that this type checking needs to be in sync with // the auto-generation logic in slang-ir-jvp-diff.cpp - - FuncType* type = m_astBuilder->create<FuncType>(); + List<Type*> paramTypes; // The backward diff return type is void // - type->resultType = m_astBuilder->getVoidType(); + auto resultType = m_astBuilder->getVoidType(); // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); - type->errorType = originalType->errorType; + SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType())); + auto errorType = originalType->getErrorType(); for (Index i = 0; i < originalType->getParamCount(); i++) { @@ -2424,7 +2393,7 @@ namespace Slang tryGetDifferentialType(m_astBuilder, outType->getValueType()); if (diffElementType) { - type->paramTypes.add(diffElementType); + paramTypes.add(diffElementType); } else { @@ -2447,16 +2416,16 @@ namespace Slang derivType = inoutType->getValueType(); } } - type->paramTypes.add(derivType); + paramTypes.add(derivType); } } // Last parameter is the initial derivative of the original return type - auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->resultType); + auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->getResultType()); if (dOutType) - type->paramTypes.add(dOutType); + paramTypes.add(dOutType); - return type; + return m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType); } struct HigherOrderInvokeExprCheckingActions @@ -2473,9 +2442,8 @@ namespace Slang if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as<GenericDecl>()) { // Get inner function - DeclRef<Decl> unspecializedInnerRef = astBuilder->getSpecializedDeclRef<Decl>( - getInner(baseFuncGenericDeclRef), - baseFuncGenericDeclRef.getSubst()); + DeclRef<Decl> unspecializedInnerRef = createDefaultSubstitutionsIfNeeded(astBuilder, semantics, + astBuilder->getMemberDeclRef(baseFuncGenericDeclRef, getInner(baseFuncGenericDeclRef))); auto callableDeclRef = unspecializedInnerRef.as<CallableDecl>(); if (!callableDeclRef) return nullptr; @@ -2677,10 +2645,10 @@ namespace Slang return false; if (!isIntegerBaseType(getVectorBaseType(vectorType))) return false; - auto constElementCount = as<ConstantIntVal>(vectorType->elementCount); + auto constElementCount = as<ConstantIntVal>(vectorType->getElementCount()); if (!constElementCount) return false; - return constElementCount->value == 3; + return constElementCount->getValue() == 3; }; expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this); if (!isInt3Type(expr->threadGroupSize->type.type)) @@ -2836,7 +2804,7 @@ namespace Slang // if( auto declRefType = as<DeclRefType>(typeExp.type) ) { - if(const auto structDeclRef = as<StructDecl>(declRefType->declRef)) + if(const auto structDeclRef = as<StructDecl>(declRefType->getDeclRef())) { if( expr->arguments.getCount() == 1 ) { @@ -3051,7 +3019,7 @@ namespace Slang auto baseType = expr->type; if (auto pointerLikeType = as<PointerLikeType>(baseType)) { - auto elementType = QualType(pointerLikeType->elementType); + auto elementType = QualType(pointerLikeType->getElementType()); elementType.isLeftValue = baseType.isLeftValue; auto derefExpr = m_astBuilder->create<DerefExpr>(); @@ -3230,7 +3198,7 @@ namespace Slang if (auto constantColCount = as<ConstantIntVal>(baseColCount)) { return CheckMatrixSwizzleExpr(memberRefExpr, baseElementType, - constantRowCount->value, constantColCount->value); + constantRowCount->getValue(), constantColCount->getValue()); } } getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on matrix of unknown size"); @@ -3350,7 +3318,7 @@ namespace Slang { if (auto constantElementCount = as<ConstantIntVal>(baseElementCount)) { - return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); + return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->getValue()); } else { @@ -3381,6 +3349,7 @@ namespace Slang m_astBuilder, this, expr->name, + namespaceDeclRef.getDecl(), namespaceDeclRef); if (!lookupResult.isValid()) { @@ -3406,7 +3375,7 @@ namespace Slang // // TODO: this duplicates a *lot* of logic with the case below. // We need to fix that. - auto type = typeType->type; + auto type = typeType->getType(); if (as<ErrorType>(type)) { @@ -3577,7 +3546,7 @@ namespace Slang for (auto lookupResult : overloadedExpr->lookupResult2) { bool shouldRemove = false; - if (lookupResult.declRef.getParent(m_astBuilder).as<InterfaceDecl>()) + if (lookupResult.declRef.getParent().as<InterfaceDecl>()) { shouldRemove = true; } @@ -3627,8 +3596,8 @@ namespace Slang { return CheckSwizzleExpr( expr, - baseVecType->elementType, - baseVecType->elementCount); + baseVecType->getElementType(), + baseVecType->getElementCount()); } else if(auto baseScalarType = as<BasicExpressionType>(baseType)) { @@ -3893,7 +3862,7 @@ namespace Slang types.reserve(expr->parameters.getCount()); for(const auto& t : expr->parameters) types.add(t.type); - auto funcType = m_astBuilder->getFuncType(std::move(types), expr->result.type); + auto funcType = m_astBuilder->getFuncType(types.getArrayView(), expr->result.type); expr->type = m_astBuilder->getTypeType(funcType); return expr; |
