summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-09-07 23:01:53 -0700
committerGitHub <noreply@github.com>2023-09-07 23:01:53 -0700
commitcb5dd19992fb77ca2be866d9c6f2f4436c8b1c1e (patch)
tree4a24573f9da79618c0e65e7462101ab3d0b640c4 /source/slang
parenta7fa215e81e510de34ac96778ac6320cbb642d64 (diff)
Incur l-value conversion cost during overload resolution. (#3195)
* Incur l-value conversion cost during overload resolution. * Fix compile error. * cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-support-types.h10
-rw-r--r--source/slang/slang-check-constraint.cpp79
-rw-r--r--source/slang/slang-check-conversion.cpp17
-rw-r--r--source/slang/slang-check-impl.h65
-rw-r--r--source/slang/slang-check-overload.cpp58
5 files changed, 140 insertions, 89 deletions
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 5d7ca49cb..0b4e9cab2 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -140,6 +140,9 @@ namespace Slang
// the element type of the vector)
kConversionCost_ScalarToVector = 1,
+ // Additional cost when casting an LValue.
+ kConversionCost_LValueCast = 800,
+
// Conversion is impossible
kConversionCost_Impossible = 0xFFFFFFFF,
};
@@ -521,6 +524,13 @@ namespace Slang
QualType(Type* type);
+ QualType(Type* type, bool isLVal)
+ : QualType(type)
+ {
+ isLeftValue = isLVal;
+ }
+
+
Type* Ptr() { return type; }
operator Type*() { return type; }
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 6e600c4af..996b88f48 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -174,8 +174,8 @@ namespace Slang
}
Type* SemanticsVisitor::TryJoinTypes(
- Type* left,
- Type* right)
+ QualType left,
+ QualType right)
{
// Easy case: they are the same type!
if (left->equals(right))
@@ -186,8 +186,8 @@ namespace Slang
{
if (auto rightBasic = as<BasicExpressionType>(right))
{
- auto costConvertRightToLeft = getConversionCost(leftBasic, rightBasic);
- auto costConvertLeftToRight = getConversionCost(rightBasic, leftBasic);
+ auto costConvertRightToLeft = getConversionCost(leftBasic, right);
+ auto costConvertLeftToRight = getConversionCost(rightBasic, left);
// Return the one that had lower conversion cost.
if (costConvertRightToLeft > costConvertLeftToRight)
@@ -217,8 +217,8 @@ namespace Slang
// Try to join the element types
auto joinElementType = TryJoinTypes(
- leftVector->getElementType(),
- rightVector->getElementType());
+ QualType(leftVector->getElementType(), left.isLeftValue),
+ QualType(rightVector->getElementType(), right.isLeftValue));
if(!joinElementType)
return nullptr;
@@ -339,13 +339,13 @@ namespace Slang
continue;
}
- Type* type = nullptr;
+ QualType type;
for (auto& c : system->constraints)
{
if (c.decl != typeParam.getDecl())
continue;
- auto cType = as<Type>(c.val);
+ auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue);
SLANG_RELEASE_ASSERT(cType);
if (!type)
@@ -360,7 +360,7 @@ namespace Slang
// failure!
return DeclRef<Decl>();
}
- type = joinType;
+ type = QualType(joinType, type.isLeftValue || cType.isLeftValue);
}
c.satisfied = true;
@@ -505,14 +505,16 @@ namespace Slang
bool SemanticsVisitor::TryUnifyVals(
ConstraintSystem& constraints,
Val* fst,
- Val* snd)
+ bool fstLVal,
+ Val* snd,
+ bool sndLVal)
{
// if both values are types, then unify types
if (auto fstType = as<Type>(fst))
{
if (auto sndType = as<Type>(snd))
{
- return TryUnifyTypes(constraints, fstType, sndType);
+ return TryUnifyTypes(constraints, QualType(fstType, fstLVal), QualType(sndType, sndLVal));
}
}
@@ -582,7 +584,9 @@ namespace Slang
bool SemanticsVisitor::tryUnifyDeclRef(
ConstraintSystem& constraints,
DeclRefBase* fst,
- DeclRefBase* snd)
+ bool fstIsLVal,
+ DeclRefBase* snd,
+ bool sndIsLVal)
{
if (fst == snd)
return true;
@@ -594,13 +598,15 @@ namespace Slang
return true;
if (fstGen == nullptr || sndGen == nullptr)
return false;
- return tryUnifyGenericAppDeclRef(constraints, fstGen, sndGen);
+ return tryUnifyGenericAppDeclRef(constraints, fstGen, fstIsLVal, sndGen, sndIsLVal);
}
bool SemanticsVisitor::tryUnifyGenericAppDeclRef(
ConstraintSystem& constraints,
GenericAppDeclRef* fst,
- GenericAppDeclRef* snd)
+ bool fstIsLVal,
+ GenericAppDeclRef* snd,
+ bool sndIsLVal)
{
SLANG_ASSERT(fst);
SLANG_ASSERT(snd);
@@ -617,7 +623,7 @@ namespace Slang
bool okay = true;
for (Index aa = 0; aa < argCount; ++aa)
{
- if (!TryUnifyVals(constraints, fstGen->getArgs()[aa], sndGen->getArgs()[aa]))
+ if (!TryUnifyVals(constraints, fstGen->getArgs()[aa], fstIsLVal, sndGen->getArgs()[aa], sndIsLVal))
{
okay = false;
}
@@ -627,7 +633,7 @@ namespace Slang
auto fstBase = fst->getBase();
auto sndBase = snd->getBase();
- if (!tryUnifyDeclRef(constraints, fstBase, sndBase))
+ if (!tryUnifyDeclRef(constraints, fstBase, fstIsLVal, sndBase, sndIsLVal))
{
okay = false;
}
@@ -636,16 +642,16 @@ namespace Slang
}
bool SemanticsVisitor::TryUnifyTypeParam(
- ConstraintSystem& constraints,
+ ConstraintSystem& constraints,
GenericTypeParamDecl* typeParamDecl,
- Type* type)
+ QualType type)
{
// We want to constrain the given type parameter
// to equal the given type.
Constraint constraint;
constraint.decl = typeParamDecl;
constraint.val = type;
-
+ constraint.isUsedAsLValue = type.isLeftValue;
constraints.constraints.add(constraint);
return true;
@@ -691,8 +697,8 @@ namespace Slang
bool SemanticsVisitor::TryUnifyTypesByStructuralMatch(
ConstraintSystem& constraints,
- Type* fst,
- Type* snd)
+ QualType fst,
+ QualType snd)
{
if (auto fstDeclRefType = as<DeclRefType>(fst))
{
@@ -718,14 +724,17 @@ namespace Slang
if (!tryUnifyDeclRef(
constraints,
fstDeclRef,
- sndDeclRef))
+ fst.isLeftValue,
+ sndDeclRef,
+ snd.isLeftValue))
{
return false;
}
return true;
}
- } else if(auto fstFunType = as<FuncType>(fst))
+ }
+ else if(auto fstFunType = as<FuncType>(fst))
{
if (auto sndFunType = as<FuncType>(snd))
{
@@ -746,8 +755,8 @@ namespace Slang
bool SemanticsVisitor::TryUnifyConjunctionType(
ConstraintSystem& constraints,
- Type* fst,
- Type* snd)
+ QualType fst,
+ QualType snd)
{
// Unifying a type `A & B` with `T` amounts to unifying
// `A` with `T` and also `B` with `T` while
@@ -759,13 +768,13 @@ namespace Slang
//
if (auto fstAndType = as<AndType>(fst))
{
- return TryUnifyTypes(constraints, fstAndType->getLeft(), snd)
- && TryUnifyTypes(constraints, fstAndType->getRight(), snd);
+ return TryUnifyTypes(constraints, QualType(fstAndType->getLeft(), fst.isLeftValue), snd)
+ && TryUnifyTypes(constraints, QualType(fstAndType->getRight(), fst.isLeftValue), snd);
}
else if (auto sndAndType = as<AndType>(snd))
{
- return TryUnifyTypes(constraints, fst, sndAndType->getLeft())
- || TryUnifyTypes(constraints, fst, sndAndType->getRight());
+ return TryUnifyTypes(constraints, fst, QualType(sndAndType->getLeft(), snd.isLeftValue))
+ || TryUnifyTypes(constraints, fst, QualType(sndAndType->getRight(), snd.isLeftValue));
}
else
return false;
@@ -773,8 +782,8 @@ namespace Slang
bool SemanticsVisitor::TryUnifyTypes(
ConstraintSystem& constraints,
- Type* fst,
- Type* snd)
+ QualType fst,
+ QualType snd)
{
if (!fst) return false;
@@ -843,8 +852,8 @@ namespace Slang
{
return TryUnifyTypes(
constraints,
- fstVectorType->getElementType(),
- sndScalarType);
+ QualType(fstVectorType->getElementType(), fst.isLeftValue),
+ QualType(sndScalarType, snd.isLeftValue));
}
}
@@ -854,8 +863,8 @@ namespace Slang
{
return TryUnifyTypes(
constraints,
- fstScalarType,
- sndVectorType->getElementType());
+ QualType(fstScalarType, fst.isLeftValue),
+ QualType(sndVectorType->getElementType(), snd.isLeftValue));
}
}
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 5a9c8df12..c4efba658 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -705,7 +705,7 @@ namespace Slang
CoercionSite site,
Type* toType,
Expr** outToExpr,
- Type* fromType,
+ QualType fromType,
Expr* fromExpr,
ConversionCost* outCost)
{
@@ -773,7 +773,7 @@ namespace Slang
auto toBase = toModified ? toModified->getBase() : toType;
//
auto fromModified = as<ModifiedType>(fromType);
- auto fromBase = fromModified ? fromModified->getBase() : fromType;
+ auto fromBase = fromModified ? QualType(fromModified->getBase(), fromType.isLeftValue) : fromType;
if((toModified || fromModified) && toBase->equals(fromBase))
@@ -1060,7 +1060,7 @@ namespace Slang
OverloadResolveContext overloadContext;
overloadContext.disallowNestedConversions = true;
overloadContext.argCount = 1;
- overloadContext.argTypes = &fromType;
+ overloadContext.argTypes = &fromType.type;
overloadContext.args = &fromExpr;
overloadContext.originalExpr = nullptr;
@@ -1191,6 +1191,11 @@ namespace Slang
}
}
}
+ if (fromType.isLeftValue)
+ {
+ // If we are implicitly casting the type of an l-value, we need to impose additional cost.
+ cost += kConversionCost_LValueCast;
+ }
if(outCost)
*outCost = cost;
@@ -1245,7 +1250,7 @@ namespace Slang
bool SemanticsVisitor::canCoerce(
Type* toType,
- Type* fromType,
+ QualType fromType,
Expr* fromExpr,
ConversionCost* outCost)
{
@@ -1380,7 +1385,7 @@ namespace Slang
bool SemanticsVisitor::canConvertImplicitly(
Type* toType,
- Type* fromType)
+ QualType fromType)
{
auto conversionCost = getConversionCost(toType, fromType);
@@ -1391,7 +1396,7 @@ namespace Slang
return true;
}
- ConversionCost SemanticsVisitor::getConversionCost(Type* toType, Type* fromType)
+ ConversionCost SemanticsVisitor::getConversionCost(Type* toType, QualType fromType)
{
ConversionCost conversionCost = kConversionCost_Impossible;
if (!canCoerce(toType, fromType, nullptr, &conversionCost))
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 37dcba3f4..544bbe170 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -46,7 +46,8 @@ namespace Slang
uint32_t dim2 : 4;
uint32_t knownConstantBitCount : 8;
uint32_t knownNegative : 1;
- uint32_t reserved : 7;
+ uint32_t isLValue : 1;
+ uint32_t reserved : 6;
uint32_t getRaw() const
{
uint32_t val;
@@ -57,16 +58,16 @@ namespace Slang
{
return getRaw() == other.getRaw();
}
- static BasicTypeKey invalid() { return BasicTypeKey{ 0xff, 0, 0, 0, 0, 0 }; }
+ static BasicTypeKey invalid() { return BasicTypeKey{ 0xff, 0, 0, 0, 0, 0, 0 }; }
};
- SLANG_FORCE_INLINE BasicTypeKey makeBasicTypeKey(BaseType baseType, IntegerLiteralValue dim1 = 0, IntegerLiteralValue dim2 = 0)
+ SLANG_FORCE_INLINE BasicTypeKey makeBasicTypeKey(BaseType baseType, IntegerLiteralValue dim1 = 0, IntegerLiteralValue dim2 = 0, bool inIsLValue = false)
{
SLANG_ASSERT(dim1 >= 0 && dim2 >= 0);
- return BasicTypeKey{ uint8_t(baseType), uint8_t(dim1), uint8_t(dim2), 0, 0, 0 };
+ return BasicTypeKey{ uint8_t(baseType), uint8_t(dim1), uint8_t(dim2), 0, 0, (inIsLValue?1u:0u), 0 };
}
- inline BasicTypeKey makeBasicTypeKey(Type* typeIn, Expr* exprIn = nullptr)
+ inline BasicTypeKey makeBasicTypeKey(QualType typeIn, Expr* exprIn = nullptr)
{
if (auto basicType = as<BasicExpressionType>(typeIn))
{
@@ -79,6 +80,7 @@ namespace Slang
}
rs.knownConstantBitCount = getIntValueBitSize(constInt->value);
}
+ rs.isLValue = typeIn.isLeftValue ? 1u : 0u;
return rs;
}
else if (auto vectorType = as<VectorExpressionType>(typeIn))
@@ -87,7 +89,7 @@ namespace Slang
{
if( auto elemBasicType = as<BasicExpressionType>(vectorType->getElementType()) )
{
- return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount->getValue());
+ return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount->getValue(), 0, typeIn.isLeftValue);
}
}
}
@@ -99,7 +101,7 @@ namespace Slang
{
if( auto elemBasicType = as<BasicExpressionType>(matrixType->getElementType()) )
{
- return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount1->getValue(), elemCount2->getValue());
+ return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount1->getValue(), elemCount2->getValue(), typeIn.isLeftValue);
}
}
}
@@ -144,7 +146,7 @@ namespace Slang
for (Index i = 0; i < opExpr->arguments.getCount(); i++)
{
- auto key = makeBasicTypeKey(opExpr->arguments[i]->type.Ptr(), opExpr->arguments[i]);
+ auto key = makeBasicTypeKey(opExpr->arguments[i]->type, opExpr->arguments[i]);
if (key.getRaw() == BasicTypeKey::invalid().getRaw())
{
return false;
@@ -1352,7 +1354,7 @@ namespace Slang
CoercionSite site,
Type* toType,
Expr** outToExpr,
- Type* fromType,
+ QualType fromType,
Expr* fromExpr,
ConversionCost* outCost);
@@ -1365,7 +1367,7 @@ namespace Slang
///
bool canCoerce(
Type* toType,
- Type* fromType,
+ QualType fromType,
Expr* fromExpr,
ConversionCost* outCost = 0);
@@ -1815,6 +1817,7 @@ namespace Slang
{
Decl* decl = nullptr; // the declaration of the thing being constraints
Val* val = nullptr; // the value to which we are constraining it
+ bool isUsedAsLValue = false; // If this constraint is for a type parameter, is the type used in an l-value parameter?
bool satisfied = false; // Has this constraint been met?
};
@@ -1909,9 +1912,9 @@ namespace Slang
/// Does there exist an implicit conversion from `fromType` to `toType`?
bool canConvertImplicitly(
Type* toType,
- Type* fromType);
+ QualType fromType);
- ConversionCost getConversionCost(Type* toType, Type* fromType);
+ ConversionCost getConversionCost(Type* toType, QualType fromType);
Type* _tryJoinTypeWithInterface(
Type* type,
@@ -1919,8 +1922,8 @@ namespace Slang
// Try to compute the "join" between two types
Type* TryJoinTypes(
- Type* left,
- Type* right);
+ QualType left,
+ QualType right);
// Try to solve a system of generic constraints.
// The `system` argument provides the constraints.
@@ -1965,7 +1968,7 @@ namespace Slang
Index getArgCount() { return argCount; }
Expr*& getArg(Index index) { return args[index]; }
- Type*& getArgType(Index index)
+ Type* getArgType(Index index)
{
if(argTypes)
return argTypes[index];
@@ -2147,22 +2150,28 @@ namespace Slang
bool TryUnifyVals(
ConstraintSystem& constraints,
Val* fst,
- Val* snd);
+ bool fstLVal,
+ Val* snd,
+ bool sndLVal);
bool tryUnifyDeclRef(
ConstraintSystem& constraints,
DeclRefBase* fst,
- DeclRefBase* snd);
+ bool fstLVal,
+ DeclRefBase* snd,
+ bool sndLVal);
bool tryUnifyGenericAppDeclRef(
ConstraintSystem& constraints,
GenericAppDeclRef* fst,
- GenericAppDeclRef* snd);
+ bool fstLVal,
+ GenericAppDeclRef* snd,
+ bool sndLVal);
bool TryUnifyTypeParam(
- ConstraintSystem& constraints,
- GenericTypeParamDecl* typeParamDecl,
- Type* type);
+ ConstraintSystem& constraints,
+ GenericTypeParamDecl* typeParamDecl,
+ QualType type);
bool TryUnifyIntParam(
ConstraintSystem& constraints,
@@ -2176,18 +2185,18 @@ namespace Slang
bool TryUnifyTypesByStructuralMatch(
ConstraintSystem& constraints,
- Type* fst,
- Type* snd);
+ QualType fst,
+ QualType snd);
bool TryUnifyTypes(
ConstraintSystem& constraints,
- Type* fst,
- Type* snd);
+ QualType fst,
+ QualType snd);
bool TryUnifyConjunctionType(
ConstraintSystem& constraints,
- Type* fst,
- Type* snd);
+ QualType fst,
+ QualType snd);
// Is the candidate extension declaration actually applicable to the given type
DeclRef<ExtensionDecl> applyExtensionToType(
@@ -2204,7 +2213,7 @@ namespace Slang
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
ArrayView<Val*> knownGenericArgs,
- List<Type*> *innerParameterTypes = nullptr);
+ List<QualType> *innerParameterTypes = nullptr);
void AddTypeOverloadCandidates(
Type* type,
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 5e626705a..ac93d6505 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -394,21 +394,44 @@ namespace Slang
return success;
}
+ static QualType getParamQualType(ASTBuilder* astBuilder, DeclRef<ParamDecl> param)
+ {
+ auto paramType = getType(astBuilder, param);
+ bool isLVal = false;
+ switch (getParameterDirection(param.getDecl()))
+ {
+ case kParameterDirection_InOut:
+ case kParameterDirection_Out:
+ case kParameterDirection_Ref:
+ isLVal = true;
+ break;
+ }
+ return QualType(paramType, isLVal);
+ }
+
+ static QualType getParamQualType(Type* paramType)
+ {
+ if (auto paramDirType = as<ParamDirectionType>(paramType))
+ {
+ if (as<OutTypeBase>(paramDirType) || as<RefType>(paramDirType))
+ return QualType(paramDirType->getValueType(), true);
+ }
+ return paramType;
+ }
+
bool SemanticsVisitor::TryCheckOverloadCandidateTypes(
OverloadResolveContext& context,
OverloadCandidate& candidate)
{
Index argCount = context.getArgCount();
- List<Type*> paramTypes;
-// List<DeclRef<ParamDecl>> params;
+ List<QualType> paramTypes;
switch (candidate.flavor)
{
case OverloadCandidate::Flavor::Func:
for (auto param : getParameters(m_astBuilder, candidate.item.declRef.as<CallableDecl>()))
{
- auto paramType = getType(m_astBuilder, param);
- paramTypes.add(paramType);
+ paramTypes.add(getParamQualType(m_astBuilder, param));
}
break;
@@ -418,13 +441,7 @@ namespace Slang
Count paramCount = funcType->getParamCount();
for (Index i = 0; i < paramCount; ++i)
{
- auto paramType = funcType->getParamType(i);
-
- if(auto paramDirectionType = as<ParamDirectionType>(paramType))
- {
- paramType = paramDirectionType->getValueType();
- }
-
+ auto paramType = getParamQualType(funcType->getParamType(i));
paramTypes.add(paramType);
}
}
@@ -445,8 +462,8 @@ namespace Slang
for (Index ii = 0; ii < argCount; ++ii)
{
auto& arg = context.getArg(ii);
- auto argType = context.getArgType(ii);
auto paramType = paramTypes[ii];
+ auto argType = QualType(context.getArgType(ii), paramType.isLeftValue);
if (!paramType)
return false;
if (!argType)
@@ -1318,7 +1335,7 @@ namespace Slang
DeclRef<GenericDecl> genericDeclRef,
OverloadResolveContext& context,
ArrayView<Val*> knownGenericArgs,
- List<Type*> *innerParameterTypes)
+ List<QualType> *innerParameterTypes)
{
// We have been asked to infer zero or more arguments to
// `genericDeclRef`, in a context where it is being applied
@@ -1360,13 +1377,13 @@ namespace Slang
if (auto funcDeclRef = as<CallableDecl>(genericDeclRef.getDecl()->inner))
{
- List<Type*> paramTypes;
+ List<QualType> paramTypes;
if (!innerParameterTypes)
{
auto params = getParameters(m_astBuilder, funcDeclRef).toArray();
for (auto param : params)
{
- paramTypes.add(getType(m_astBuilder, param));
+ paramTypes.add(getParamQualType(m_astBuilder, param));
}
innerParameterTypes = &paramTypes;
}
@@ -1408,11 +1425,12 @@ namespace Slang
//
// So the question is then whether a mismatch during the
// unification step should be taken as an immediate failure...
-
+ auto argType = context.getArgTypeForInference(aa, this);
+ auto paramType = (*innerParameterTypes)[aa];
TryUnifyTypes(
constraints,
- context.getArgTypeForInference(aa, this),
- (*innerParameterTypes)[aa]);
+ QualType(argType, paramType.isLeftValue),
+ paramType);
}
}
else
@@ -1679,10 +1697,10 @@ namespace Slang
SLANG_ASSERT(diffFuncType);
// Extract parameter list from processed type.
- List<Type*> paramTypes;
+ List<QualType> paramTypes;
for (Index ii = 0; ii < diffFuncType->getParamCount(); ii++)
- paramTypes.add(removeParamDirType(diffFuncType->getParamType(ii)));
+ paramTypes.add(getParamQualType(diffFuncType->getParamType(ii)));
// Try to infer generic arguments, based on the updated context.
OverloadResolveContext subContext = context;