summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-04 15:47:39 -0700
committerGitHub <noreply@github.com>2023-08-04 15:47:39 -0700
commita2d90fb275962da84611160f8ddd74d934a68dbd (patch)
tree066084537b9f4fe1f367de100ed6638a88a028c1 /source/slang/slang-check-expr.cpp
parent17da4f0dec2b86ba3a4bdaf8a2ae112047d23623 (diff)
Redesign `DeclRef` and systematic `Val` deduplication (#3049)
* Redesign DeclRef + Deduplicate Val. * Update project files * Fix warning. * Fix. * Fix. * Remove `Val::_equalsImplOverride`. * Rmove `Val::_getHashCodeOverride`. * Remove `semanticVisitor` param from `resolve`. * Cleanups. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp185
1 files changed, 77 insertions, 108 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3c90c3ed8..e343e3113 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -22,7 +22,7 @@ namespace Slang
DeclRefType* SemanticsVisitor::getExprDeclRefType(Expr * expr)
{
if (auto typetype = as<TypeType>(expr->type))
- return dynamicCast<DeclRefType>(typetype->type);
+ return dynamicCast<DeclRefType>(typetype->getType());
else
return as<DeclRefType>(expr->type);
}
@@ -154,10 +154,8 @@ namespace Slang
//
return maybeMoveTemp(expr, [&](DeclRef<VarDeclBase> varDeclRef)
{
- ExtractExistentialType* openedType = m_astBuilder->create<ExtractExistentialType>();
- openedType->declRef = varDeclRef;
- openedType->originalInterfaceType = expr->type.type;
- openedType->originalInterfaceDeclRef = interfaceDeclRef;
+ ExtractExistentialType* openedType = m_astBuilder->getOrCreate<ExtractExistentialType>(
+ varDeclRef, expr->type.type, interfaceDeclRef);
ExtractExistentialValueExpr* openedValue = m_astBuilder->create<ExtractExistentialValueExpr>();
openedValue->declRef = varDeclRef;
@@ -202,29 +200,9 @@ namespace Slang
if(auto declRefType = as<DeclRefType>(exprType))
{
- if(auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>())
+ if(auto interfaceDeclRef = declRefType->getDeclRef().as<InterfaceDecl>())
{
- // Is there an this-type substitution being applied, so that
- // we are referencing the interface type through a concrete
- // type (e.g., a type parameter constrained to this interface)?
- //
- // Because of the way that substitutions need to mirror the nesting
- // hierarchy of declarations, any this-type substitution pertaining
- // to the chosen interface decl must be the first substitution on
- // the list (which is a linked list from the "inside" out).
- //
- auto thisTypeSubst = as<ThisTypeSubstitution>(interfaceDeclRef.getSubst());
- if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.getDecl())
- {
- // This isn't really an existential type, because somebody
- // has already filled in a this-type substitution.
- }
- else
- {
- // Okay, here is the case that matters.
- //
- return openExistential(expr, interfaceDeclRef);
- }
+ return openExistential(expr, interfaceDeclRef);
}
}
@@ -317,7 +295,7 @@ namespace Slang
// actually names a type, because in that case we are doing
// a static member reference.
//
- if (auto typeType = as<TypeType>(baseExpr->type))
+ if (auto typeType = as<TypeType>(baseExpr->type->getCanonicalType()))
{
// Before forming the reference, we will check if the
// member being referenced can even be used as a static
@@ -340,7 +318,7 @@ namespace Slang
getSink()->diagnose(
loc,
Diagnostics::staticRefToNonStaticMember,
- typeType->type,
+ typeType->getType(),
declRef.getName());
}
@@ -493,9 +471,9 @@ namespace Slang
case LookupResultItem::Breadcrumb::Kind::SuperType:
{
auto witness = as<SubtypeWitness>(breadcrumb->val);
- if (auto subDeclRefType = as<DeclRefType>(witness->sub))
+ if (auto subDeclRefType = as<DeclRefType>(witness->getSub()))
{
- if (!as<InterfaceDecl>(subDeclRefType->declRef.getDecl()))
+ if (!as<InterfaceDecl>(subDeclRefType->getDeclRef().getDecl()))
{
// Store the inner most concrete super type.
subType = subDeclRefType;
@@ -515,10 +493,13 @@ namespace Slang
return nullptr;
// Don't synthesize for generic parameters.
- auto parent = as<AggTypeDecl>(subType->declRef.getDecl());
+ auto parent = as<AggTypeDecl>(subType->getDeclRef().getDecl());
if (!parent)
return nullptr;
+ // Don't synthesize for ThisType.
+ if (as<ThisTypeDecl>(subType->getDeclRef().getDecl()))
+ return nullptr;
// If we reach here, we are expecting a synthesized decl defined in `subType`.
// Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl
@@ -607,7 +588,7 @@ namespace Slang
//
auto witness = as<SubtypeWitness>(breadcrumb->val);
SLANG_ASSERT(witness);
- auto expr = createCastToSuperTypeExpr(witness->sup, bb, witness);
+ auto expr = createCastToSuperTypeExpr(witness->getSup(), bb, witness);
// Note that we allow a cast of an l-value to
// be used as an l-value here because it enables
@@ -926,7 +907,7 @@ namespace Slang
if (auto declRefType = as<DeclRefType>(type))
{
- if (auto builtinRequirement = declRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>())
+ if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>())
{
if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType)
{
@@ -935,6 +916,7 @@ namespace Slang
return type;
}
}
+ type = resolveType(type);
if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())))
{
auto diffTypeLookupResult = lookUpMember(
@@ -964,10 +946,10 @@ namespace Slang
auto diffTypeExpr = ConstructLookupResultExpr(
diffTypeLookupResult.item,
baseTypeExpr,
- declRefType->declRef.getLoc(),
+ declRefType->getDeclRef().getLoc(),
baseTypeExpr);
- return ExtractTypeFromTypeRepr(diffTypeExpr);
+ return resolveType(ExtractTypeFromTypeRepr(diffTypeExpr));
}
}
}
@@ -991,7 +973,7 @@ namespace Slang
SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr);
if (witness)
{
- m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.addIfNotExists(type->declRef, witness);
+ m_parentDifferentiableAttr->addType(type->getDeclRef(), witness);
}
}
@@ -1048,7 +1030,7 @@ namespace Slang
{
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
}
- if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
+ if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
{
foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
{
@@ -1061,23 +1043,13 @@ namespace Slang
maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType);
});
}
- for (auto subst = declRefType->declRef.getSubst(); subst; subst = subst->getOuter())
- {
- if (auto genSubst = as<GenericSubstitution>(subst))
+ SubstitutionSet(declRefType->getDeclRef()).forEachSubstitutionArg([&](Val* arg)
{
- for (auto arg : genSubst->getArgs())
+ if (auto typeArg = as<Type>(arg))
{
- if (auto typeArg = as<Type>(arg))
- {
- maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg);
- }
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg);
}
- }
- else if (auto thisSubst = as<ThisTypeSubstitution>(subst))
- {
- maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, thisSubst->witness->sub);
- }
- }
+ });
return;
}
}
@@ -1302,7 +1274,7 @@ namespace Slang
if (auto constArgVal = as<ConstantIntVal>(argVal))
{
- constArgVals[a] = constArgVal->value;
+ constArgVals[a] = constArgVal->getValue();
}
else
{
@@ -1366,12 +1338,13 @@ namespace Slang
|| opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") ||
opName == getName("?:") || opName == getName("<<") || opName == getName(">>"))
{
- auto result = m_astBuilder->create<FuncCallIntVal>(invokeExpr.getExpr()->type.type);
- result->args.addRange(argVals, argCount);
- result->funcDeclRef = funcDeclRef;
- result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute(
- m_astBuilder, funcDeclRefExpr.getSubsts()));
- SLANG_RELEASE_ASSERT(result->funcType);
+ auto result = m_astBuilder->getOrCreate<FuncCallIntVal>(
+ invokeExpr.getExpr()->type.type,
+ funcDeclRef,
+ as<Type>(funcDeclRefExpr.getExpr()->type->substitute(
+ m_astBuilder, funcDeclRefExpr.getSubsts())),
+ makeArrayView(argVals, argCount));
+ SLANG_RELEASE_ASSERT(result->getFuncType());
return result;
}
return nullptr;
@@ -1507,18 +1480,14 @@ namespace Slang
if (isInterfaceRequirement(decl))
{
- for (auto subst = declRef.getSubst(); subst; subst = subst->getOuter())
- {
- if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst))
- {
- auto val = WitnessLookupIntVal::tryFold(
- m_astBuilder,
- thisTypeSubst->witness,
- decl,
- declRef.substitute(m_astBuilder, decl->type.type));
- return as<IntVal>(val);
- }
- }
+ auto witness = findThisTypeWitness(SubstitutionSet(declRef), as<InterfaceDecl>(decl->parentDecl));
+
+ auto val = WitnessLookupIntVal::tryFold(
+ m_astBuilder,
+ witness,
+ decl,
+ declRef.substitute(m_astBuilder, decl->type.type));
+ return as<IntVal>(val);
}
if (!getInitExpr(m_astBuilder, declRef))
@@ -1785,7 +1754,7 @@ namespace Slang
getSink()->diagnose(subscriptExpr, Diagnostics::multiDimensionalArrayNotSupported);
}
- auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type));
+ auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->getType()));
auto arrayType = getArrayType(
m_astBuilder,
elementType,
@@ -1804,7 +1773,7 @@ namespace Slang
{
return CheckSimpleSubscriptExpr(
subscriptExpr,
- vecType->elementType);
+ vecType->getElementType());
}
else if (auto matType = as<MatrixExpressionType>(baseType))
{
@@ -1975,8 +1944,8 @@ namespace Slang
if (basicTypeA && basicTypeB)
{
- const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->baseType);
- const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->baseType);
+ const auto& infoA = BaseTypeInfo::getInfo(basicTypeA->getBaseType());
+ const auto& infoB = BaseTypeInfo::getInfo(basicTypeB->getBaseType());
// TODO(JS): Initially this tries to limit where LValueImplict casts happen.
// We could in principal allow different sizes, as long as we converted to a temprorary
@@ -2021,7 +1990,7 @@ namespace Slang
// if this is still an invoke expression, test arguments passed to inout/out parameter are LValues
if(auto funcType = as<FuncType>(invoke->functionExpr->type))
{
- if (!funcType->errorType->equals(m_astBuilder->getBottomType()))
+ if (!funcType->getErrorType()->equals(m_astBuilder->getBottomType()))
{
// If the callee throws, make sure we are inside a try clause.
if (m_enclosingTryClauseType == TryClauseType::None)
@@ -2230,7 +2199,7 @@ namespace Slang
return result;
}
- Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr *expr)
+ Expr* SemanticsExprVisitor::visitInvokeExpr(InvokeExpr* expr)
{
// check the base expression first
expr->functionExpr = CheckTerm(expr->functionExpr);
@@ -2312,6 +2281,7 @@ namespace Slang
auto lookupResult = lookUp(
m_astBuilder,
this, expr->name, expr->scope);
+
if (expr->name == getSession()->getCompletionRequestTokenName())
{
auto scopeKind = CompletionSuggestions::ScopeKind::Expr;
@@ -2357,7 +2327,7 @@ namespace Slang
if (auto modifiedType = as<ModifiedType>(primalType))
{
if (modifiedType->findModifier<NoDiffModifierVal>())
- return modifiedType->base;
+ return modifiedType->getBase();
}
// Get a reference to the builtin 'IDifferentiable' interface
@@ -2379,23 +2349,23 @@ namespace Slang
// Resolve JVP type here.
// Note that this type checking needs to be in sync with
// the auto-generation logic in slang-ir-jvp-diff.cpp
-
- FuncType* jvpType = m_astBuilder->create<FuncType>();
+ List<Type*> paramTypes;
// The JVP return type is float if primal return type is float
// void otherwise.
//
- jvpType->resultType = getDifferentialPairType(originalType->getResultType());
+ auto resultType = getDifferentialPairType(originalType->getResultType());
// No support for differentiating function that throw errors, for now.
- SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType()));
- jvpType->errorType = originalType->errorType;
+ SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType()));
+ auto errorType = originalType->getErrorType();
for (Index i = 0; i < originalType->getParamCount(); i++)
{
if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i)))
- jvpType->paramTypes.add(jvpParamType);
+ paramTypes.add(jvpParamType);
}
+ FuncType* jvpType = m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType);
return jvpType;
}
@@ -2405,16 +2375,15 @@ namespace Slang
// Resolve backward diff type here.
// Note that this type checking needs to be in sync with
// the auto-generation logic in slang-ir-jvp-diff.cpp
-
- FuncType* type = m_astBuilder->create<FuncType>();
+ List<Type*> paramTypes;
// The backward diff return type is void
//
- type->resultType = m_astBuilder->getVoidType();
+ auto resultType = m_astBuilder->getVoidType();
// No support for differentiating function that throw errors, for now.
- SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType()));
- type->errorType = originalType->errorType;
+ SLANG_ASSERT(originalType->getErrorType()->equals(m_astBuilder->getBottomType()));
+ auto errorType = originalType->getErrorType();
for (Index i = 0; i < originalType->getParamCount(); i++)
{
@@ -2424,7 +2393,7 @@ namespace Slang
tryGetDifferentialType(m_astBuilder, outType->getValueType());
if (diffElementType)
{
- type->paramTypes.add(diffElementType);
+ paramTypes.add(diffElementType);
}
else
{
@@ -2447,16 +2416,16 @@ namespace Slang
derivType = inoutType->getValueType();
}
}
- type->paramTypes.add(derivType);
+ paramTypes.add(derivType);
}
}
// Last parameter is the initial derivative of the original return type
- auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->resultType);
+ auto dOutType = tryGetDifferentialType(m_astBuilder, originalType->getResultType());
if (dOutType)
- type->paramTypes.add(dOutType);
+ paramTypes.add(dOutType);
- return type;
+ return m_astBuilder->getOrCreate<FuncType>(paramTypes.getArrayView(), resultType, errorType);
}
struct HigherOrderInvokeExprCheckingActions
@@ -2473,9 +2442,8 @@ namespace Slang
if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as<GenericDecl>())
{
// Get inner function
- DeclRef<Decl> unspecializedInnerRef = astBuilder->getSpecializedDeclRef<Decl>(
- getInner(baseFuncGenericDeclRef),
- baseFuncGenericDeclRef.getSubst());
+ DeclRef<Decl> unspecializedInnerRef = createDefaultSubstitutionsIfNeeded(astBuilder, semantics,
+ astBuilder->getMemberDeclRef(baseFuncGenericDeclRef, getInner(baseFuncGenericDeclRef)));
auto callableDeclRef = unspecializedInnerRef.as<CallableDecl>();
if (!callableDeclRef)
return nullptr;
@@ -2677,10 +2645,10 @@ namespace Slang
return false;
if (!isIntegerBaseType(getVectorBaseType(vectorType)))
return false;
- auto constElementCount = as<ConstantIntVal>(vectorType->elementCount);
+ auto constElementCount = as<ConstantIntVal>(vectorType->getElementCount());
if (!constElementCount)
return false;
- return constElementCount->value == 3;
+ return constElementCount->getValue() == 3;
};
expr->threadGroupSize = dispatchExpr(expr->threadGroupSize, *this);
if (!isInt3Type(expr->threadGroupSize->type.type))
@@ -2836,7 +2804,7 @@ namespace Slang
//
if( auto declRefType = as<DeclRefType>(typeExp.type) )
{
- if(const auto structDeclRef = as<StructDecl>(declRefType->declRef))
+ if(const auto structDeclRef = as<StructDecl>(declRefType->getDeclRef()))
{
if( expr->arguments.getCount() == 1 )
{
@@ -3051,7 +3019,7 @@ namespace Slang
auto baseType = expr->type;
if (auto pointerLikeType = as<PointerLikeType>(baseType))
{
- auto elementType = QualType(pointerLikeType->elementType);
+ auto elementType = QualType(pointerLikeType->getElementType());
elementType.isLeftValue = baseType.isLeftValue;
auto derefExpr = m_astBuilder->create<DerefExpr>();
@@ -3230,7 +3198,7 @@ namespace Slang
if (auto constantColCount = as<ConstantIntVal>(baseColCount))
{
return CheckMatrixSwizzleExpr(memberRefExpr, baseElementType,
- constantRowCount->value, constantColCount->value);
+ constantRowCount->getValue(), constantColCount->getValue());
}
}
getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on matrix of unknown size");
@@ -3350,7 +3318,7 @@ namespace Slang
{
if (auto constantElementCount = as<ConstantIntVal>(baseElementCount))
{
- return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value);
+ return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->getValue());
}
else
{
@@ -3381,6 +3349,7 @@ namespace Slang
m_astBuilder,
this,
expr->name,
+ namespaceDeclRef.getDecl(),
namespaceDeclRef);
if (!lookupResult.isValid())
{
@@ -3406,7 +3375,7 @@ namespace Slang
//
// TODO: this duplicates a *lot* of logic with the case below.
// We need to fix that.
- auto type = typeType->type;
+ auto type = typeType->getType();
if (as<ErrorType>(type))
{
@@ -3577,7 +3546,7 @@ namespace Slang
for (auto lookupResult : overloadedExpr->lookupResult2)
{
bool shouldRemove = false;
- if (lookupResult.declRef.getParent(m_astBuilder).as<InterfaceDecl>())
+ if (lookupResult.declRef.getParent().as<InterfaceDecl>())
{
shouldRemove = true;
}
@@ -3627,8 +3596,8 @@ namespace Slang
{
return CheckSwizzleExpr(
expr,
- baseVecType->elementType,
- baseVecType->elementCount);
+ baseVecType->getElementType(),
+ baseVecType->getElementCount());
}
else if(auto baseScalarType = as<BasicExpressionType>(baseType))
{
@@ -3893,7 +3862,7 @@ namespace Slang
types.reserve(expr->parameters.getCount());
for(const auto& t : expr->parameters)
types.add(t.type);
- auto funcType = m_astBuilder->getFuncType(std::move(types), expr->result.type);
+ auto funcType = m_astBuilder->getFuncType(types.getArrayView(), expr->result.type);
expr->type = m_astBuilder->getTypeType(funcType);
return expr;