diff options
| author | Yong He <yonghe@outlook.com> | 2023-09-07 23:01:53 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-07 23:01:53 -0700 |
| commit | cb5dd19992fb77ca2be866d9c6f2f4436c8b1c1e (patch) | |
| tree | 4a24573f9da79618c0e65e7462101ab3d0b640c4 | |
| parent | a7fa215e81e510de34ac96778ac6320cbb642d64 (diff) | |
Incur l-value conversion cost during overload resolution. (#3195)
* Incur l-value conversion cost during overload resolution.
* Fix compile error.
* cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 79 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 17 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 65 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 58 | ||||
| -rw-r--r-- | tests/language-feature/overload-resolution.slang | 45 |
6 files changed, 185 insertions, 89 deletions
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 5d7ca49cb..0b4e9cab2 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -140,6 +140,9 @@ namespace Slang // the element type of the vector) kConversionCost_ScalarToVector = 1, + // Additional cost when casting an LValue. + kConversionCost_LValueCast = 800, + // Conversion is impossible kConversionCost_Impossible = 0xFFFFFFFF, }; @@ -521,6 +524,13 @@ namespace Slang QualType(Type* type); + QualType(Type* type, bool isLVal) + : QualType(type) + { + isLeftValue = isLVal; + } + + Type* Ptr() { return type; } operator Type*() { return type; } diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 6e600c4af..996b88f48 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -174,8 +174,8 @@ namespace Slang } Type* SemanticsVisitor::TryJoinTypes( - Type* left, - Type* right) + QualType left, + QualType right) { // Easy case: they are the same type! if (left->equals(right)) @@ -186,8 +186,8 @@ namespace Slang { if (auto rightBasic = as<BasicExpressionType>(right)) { - auto costConvertRightToLeft = getConversionCost(leftBasic, rightBasic); - auto costConvertLeftToRight = getConversionCost(rightBasic, leftBasic); + auto costConvertRightToLeft = getConversionCost(leftBasic, right); + auto costConvertLeftToRight = getConversionCost(rightBasic, left); // Return the one that had lower conversion cost. if (costConvertRightToLeft > costConvertLeftToRight) @@ -217,8 +217,8 @@ namespace Slang // Try to join the element types auto joinElementType = TryJoinTypes( - leftVector->getElementType(), - rightVector->getElementType()); + QualType(leftVector->getElementType(), left.isLeftValue), + QualType(rightVector->getElementType(), right.isLeftValue)); if(!joinElementType) return nullptr; @@ -339,13 +339,13 @@ namespace Slang continue; } - Type* type = nullptr; + QualType type; for (auto& c : system->constraints) { if (c.decl != typeParam.getDecl()) continue; - auto cType = as<Type>(c.val); + auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue); SLANG_RELEASE_ASSERT(cType); if (!type) @@ -360,7 +360,7 @@ namespace Slang // failure! return DeclRef<Decl>(); } - type = joinType; + type = QualType(joinType, type.isLeftValue || cType.isLeftValue); } c.satisfied = true; @@ -505,14 +505,16 @@ namespace Slang bool SemanticsVisitor::TryUnifyVals( ConstraintSystem& constraints, Val* fst, - Val* snd) + bool fstLVal, + Val* snd, + bool sndLVal) { // if both values are types, then unify types if (auto fstType = as<Type>(fst)) { if (auto sndType = as<Type>(snd)) { - return TryUnifyTypes(constraints, fstType, sndType); + return TryUnifyTypes(constraints, QualType(fstType, fstLVal), QualType(sndType, sndLVal)); } } @@ -582,7 +584,9 @@ namespace Slang bool SemanticsVisitor::tryUnifyDeclRef( ConstraintSystem& constraints, DeclRefBase* fst, - DeclRefBase* snd) + bool fstIsLVal, + DeclRefBase* snd, + bool sndIsLVal) { if (fst == snd) return true; @@ -594,13 +598,15 @@ namespace Slang return true; if (fstGen == nullptr || sndGen == nullptr) return false; - return tryUnifyGenericAppDeclRef(constraints, fstGen, sndGen); + return tryUnifyGenericAppDeclRef(constraints, fstGen, fstIsLVal, sndGen, sndIsLVal); } bool SemanticsVisitor::tryUnifyGenericAppDeclRef( ConstraintSystem& constraints, GenericAppDeclRef* fst, - GenericAppDeclRef* snd) + bool fstIsLVal, + GenericAppDeclRef* snd, + bool sndIsLVal) { SLANG_ASSERT(fst); SLANG_ASSERT(snd); @@ -617,7 +623,7 @@ namespace Slang bool okay = true; for (Index aa = 0; aa < argCount; ++aa) { - if (!TryUnifyVals(constraints, fstGen->getArgs()[aa], sndGen->getArgs()[aa])) + if (!TryUnifyVals(constraints, fstGen->getArgs()[aa], fstIsLVal, sndGen->getArgs()[aa], sndIsLVal)) { okay = false; } @@ -627,7 +633,7 @@ namespace Slang auto fstBase = fst->getBase(); auto sndBase = snd->getBase(); - if (!tryUnifyDeclRef(constraints, fstBase, sndBase)) + if (!tryUnifyDeclRef(constraints, fstBase, fstIsLVal, sndBase, sndIsLVal)) { okay = false; } @@ -636,16 +642,16 @@ namespace Slang } bool SemanticsVisitor::TryUnifyTypeParam( - ConstraintSystem& constraints, + ConstraintSystem& constraints, GenericTypeParamDecl* typeParamDecl, - Type* type) + QualType type) { // We want to constrain the given type parameter // to equal the given type. Constraint constraint; constraint.decl = typeParamDecl; constraint.val = type; - + constraint.isUsedAsLValue = type.isLeftValue; constraints.constraints.add(constraint); return true; @@ -691,8 +697,8 @@ namespace Slang bool SemanticsVisitor::TryUnifyTypesByStructuralMatch( ConstraintSystem& constraints, - Type* fst, - Type* snd) + QualType fst, + QualType snd) { if (auto fstDeclRefType = as<DeclRefType>(fst)) { @@ -718,14 +724,17 @@ namespace Slang if (!tryUnifyDeclRef( constraints, fstDeclRef, - sndDeclRef)) + fst.isLeftValue, + sndDeclRef, + snd.isLeftValue)) { return false; } return true; } - } else if(auto fstFunType = as<FuncType>(fst)) + } + else if(auto fstFunType = as<FuncType>(fst)) { if (auto sndFunType = as<FuncType>(snd)) { @@ -746,8 +755,8 @@ namespace Slang bool SemanticsVisitor::TryUnifyConjunctionType( ConstraintSystem& constraints, - Type* fst, - Type* snd) + QualType fst, + QualType snd) { // Unifying a type `A & B` with `T` amounts to unifying // `A` with `T` and also `B` with `T` while @@ -759,13 +768,13 @@ namespace Slang // if (auto fstAndType = as<AndType>(fst)) { - return TryUnifyTypes(constraints, fstAndType->getLeft(), snd) - && TryUnifyTypes(constraints, fstAndType->getRight(), snd); + return TryUnifyTypes(constraints, QualType(fstAndType->getLeft(), fst.isLeftValue), snd) + && TryUnifyTypes(constraints, QualType(fstAndType->getRight(), fst.isLeftValue), snd); } else if (auto sndAndType = as<AndType>(snd)) { - return TryUnifyTypes(constraints, fst, sndAndType->getLeft()) - || TryUnifyTypes(constraints, fst, sndAndType->getRight()); + return TryUnifyTypes(constraints, fst, QualType(sndAndType->getLeft(), snd.isLeftValue)) + || TryUnifyTypes(constraints, fst, QualType(sndAndType->getRight(), snd.isLeftValue)); } else return false; @@ -773,8 +782,8 @@ namespace Slang bool SemanticsVisitor::TryUnifyTypes( ConstraintSystem& constraints, - Type* fst, - Type* snd) + QualType fst, + QualType snd) { if (!fst) return false; @@ -843,8 +852,8 @@ namespace Slang { return TryUnifyTypes( constraints, - fstVectorType->getElementType(), - sndScalarType); + QualType(fstVectorType->getElementType(), fst.isLeftValue), + QualType(sndScalarType, snd.isLeftValue)); } } @@ -854,8 +863,8 @@ namespace Slang { return TryUnifyTypes( constraints, - fstScalarType, - sndVectorType->getElementType()); + QualType(fstScalarType, fst.isLeftValue), + QualType(sndVectorType->getElementType(), snd.isLeftValue)); } } diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 5a9c8df12..c4efba658 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -705,7 +705,7 @@ namespace Slang CoercionSite site, Type* toType, Expr** outToExpr, - Type* fromType, + QualType fromType, Expr* fromExpr, ConversionCost* outCost) { @@ -773,7 +773,7 @@ namespace Slang auto toBase = toModified ? toModified->getBase() : toType; // auto fromModified = as<ModifiedType>(fromType); - auto fromBase = fromModified ? fromModified->getBase() : fromType; + auto fromBase = fromModified ? QualType(fromModified->getBase(), fromType.isLeftValue) : fromType; if((toModified || fromModified) && toBase->equals(fromBase)) @@ -1060,7 +1060,7 @@ namespace Slang OverloadResolveContext overloadContext; overloadContext.disallowNestedConversions = true; overloadContext.argCount = 1; - overloadContext.argTypes = &fromType; + overloadContext.argTypes = &fromType.type; overloadContext.args = &fromExpr; overloadContext.originalExpr = nullptr; @@ -1191,6 +1191,11 @@ namespace Slang } } } + if (fromType.isLeftValue) + { + // If we are implicitly casting the type of an l-value, we need to impose additional cost. + cost += kConversionCost_LValueCast; + } if(outCost) *outCost = cost; @@ -1245,7 +1250,7 @@ namespace Slang bool SemanticsVisitor::canCoerce( Type* toType, - Type* fromType, + QualType fromType, Expr* fromExpr, ConversionCost* outCost) { @@ -1380,7 +1385,7 @@ namespace Slang bool SemanticsVisitor::canConvertImplicitly( Type* toType, - Type* fromType) + QualType fromType) { auto conversionCost = getConversionCost(toType, fromType); @@ -1391,7 +1396,7 @@ namespace Slang return true; } - ConversionCost SemanticsVisitor::getConversionCost(Type* toType, Type* fromType) + ConversionCost SemanticsVisitor::getConversionCost(Type* toType, QualType fromType) { ConversionCost conversionCost = kConversionCost_Impossible; if (!canCoerce(toType, fromType, nullptr, &conversionCost)) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 37dcba3f4..544bbe170 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -46,7 +46,8 @@ namespace Slang uint32_t dim2 : 4; uint32_t knownConstantBitCount : 8; uint32_t knownNegative : 1; - uint32_t reserved : 7; + uint32_t isLValue : 1; + uint32_t reserved : 6; uint32_t getRaw() const { uint32_t val; @@ -57,16 +58,16 @@ namespace Slang { return getRaw() == other.getRaw(); } - static BasicTypeKey invalid() { return BasicTypeKey{ 0xff, 0, 0, 0, 0, 0 }; } + static BasicTypeKey invalid() { return BasicTypeKey{ 0xff, 0, 0, 0, 0, 0, 0 }; } }; - SLANG_FORCE_INLINE BasicTypeKey makeBasicTypeKey(BaseType baseType, IntegerLiteralValue dim1 = 0, IntegerLiteralValue dim2 = 0) + SLANG_FORCE_INLINE BasicTypeKey makeBasicTypeKey(BaseType baseType, IntegerLiteralValue dim1 = 0, IntegerLiteralValue dim2 = 0, bool inIsLValue = false) { SLANG_ASSERT(dim1 >= 0 && dim2 >= 0); - return BasicTypeKey{ uint8_t(baseType), uint8_t(dim1), uint8_t(dim2), 0, 0, 0 }; + return BasicTypeKey{ uint8_t(baseType), uint8_t(dim1), uint8_t(dim2), 0, 0, (inIsLValue?1u:0u), 0 }; } - inline BasicTypeKey makeBasicTypeKey(Type* typeIn, Expr* exprIn = nullptr) + inline BasicTypeKey makeBasicTypeKey(QualType typeIn, Expr* exprIn = nullptr) { if (auto basicType = as<BasicExpressionType>(typeIn)) { @@ -79,6 +80,7 @@ namespace Slang } rs.knownConstantBitCount = getIntValueBitSize(constInt->value); } + rs.isLValue = typeIn.isLeftValue ? 1u : 0u; return rs; } else if (auto vectorType = as<VectorExpressionType>(typeIn)) @@ -87,7 +89,7 @@ namespace Slang { if( auto elemBasicType = as<BasicExpressionType>(vectorType->getElementType()) ) { - return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount->getValue()); + return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount->getValue(), 0, typeIn.isLeftValue); } } } @@ -99,7 +101,7 @@ namespace Slang { if( auto elemBasicType = as<BasicExpressionType>(matrixType->getElementType()) ) { - return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount1->getValue(), elemCount2->getValue()); + return makeBasicTypeKey(elemBasicType->getBaseType(), elemCount1->getValue(), elemCount2->getValue(), typeIn.isLeftValue); } } } @@ -144,7 +146,7 @@ namespace Slang for (Index i = 0; i < opExpr->arguments.getCount(); i++) { - auto key = makeBasicTypeKey(opExpr->arguments[i]->type.Ptr(), opExpr->arguments[i]); + auto key = makeBasicTypeKey(opExpr->arguments[i]->type, opExpr->arguments[i]); if (key.getRaw() == BasicTypeKey::invalid().getRaw()) { return false; @@ -1352,7 +1354,7 @@ namespace Slang CoercionSite site, Type* toType, Expr** outToExpr, - Type* fromType, + QualType fromType, Expr* fromExpr, ConversionCost* outCost); @@ -1365,7 +1367,7 @@ namespace Slang /// bool canCoerce( Type* toType, - Type* fromType, + QualType fromType, Expr* fromExpr, ConversionCost* outCost = 0); @@ -1815,6 +1817,7 @@ namespace Slang { Decl* decl = nullptr; // the declaration of the thing being constraints Val* val = nullptr; // the value to which we are constraining it + bool isUsedAsLValue = false; // If this constraint is for a type parameter, is the type used in an l-value parameter? bool satisfied = false; // Has this constraint been met? }; @@ -1909,9 +1912,9 @@ namespace Slang /// Does there exist an implicit conversion from `fromType` to `toType`? bool canConvertImplicitly( Type* toType, - Type* fromType); + QualType fromType); - ConversionCost getConversionCost(Type* toType, Type* fromType); + ConversionCost getConversionCost(Type* toType, QualType fromType); Type* _tryJoinTypeWithInterface( Type* type, @@ -1919,8 +1922,8 @@ namespace Slang // Try to compute the "join" between two types Type* TryJoinTypes( - Type* left, - Type* right); + QualType left, + QualType right); // Try to solve a system of generic constraints. // The `system` argument provides the constraints. @@ -1965,7 +1968,7 @@ namespace Slang Index getArgCount() { return argCount; } Expr*& getArg(Index index) { return args[index]; } - Type*& getArgType(Index index) + Type* getArgType(Index index) { if(argTypes) return argTypes[index]; @@ -2147,22 +2150,28 @@ namespace Slang bool TryUnifyVals( ConstraintSystem& constraints, Val* fst, - Val* snd); + bool fstLVal, + Val* snd, + bool sndLVal); bool tryUnifyDeclRef( ConstraintSystem& constraints, DeclRefBase* fst, - DeclRefBase* snd); + bool fstLVal, + DeclRefBase* snd, + bool sndLVal); bool tryUnifyGenericAppDeclRef( ConstraintSystem& constraints, GenericAppDeclRef* fst, - GenericAppDeclRef* snd); + bool fstLVal, + GenericAppDeclRef* snd, + bool sndLVal); bool TryUnifyTypeParam( - ConstraintSystem& constraints, - GenericTypeParamDecl* typeParamDecl, - Type* type); + ConstraintSystem& constraints, + GenericTypeParamDecl* typeParamDecl, + QualType type); bool TryUnifyIntParam( ConstraintSystem& constraints, @@ -2176,18 +2185,18 @@ namespace Slang bool TryUnifyTypesByStructuralMatch( ConstraintSystem& constraints, - Type* fst, - Type* snd); + QualType fst, + QualType snd); bool TryUnifyTypes( ConstraintSystem& constraints, - Type* fst, - Type* snd); + QualType fst, + QualType snd); bool TryUnifyConjunctionType( ConstraintSystem& constraints, - Type* fst, - Type* snd); + QualType fst, + QualType snd); // Is the candidate extension declaration actually applicable to the given type DeclRef<ExtensionDecl> applyExtensionToType( @@ -2204,7 +2213,7 @@ namespace Slang DeclRef<GenericDecl> genericDeclRef, OverloadResolveContext& context, ArrayView<Val*> knownGenericArgs, - List<Type*> *innerParameterTypes = nullptr); + List<QualType> *innerParameterTypes = nullptr); void AddTypeOverloadCandidates( Type* type, diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 5e626705a..ac93d6505 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -394,21 +394,44 @@ namespace Slang return success; } + static QualType getParamQualType(ASTBuilder* astBuilder, DeclRef<ParamDecl> param) + { + auto paramType = getType(astBuilder, param); + bool isLVal = false; + switch (getParameterDirection(param.getDecl())) + { + case kParameterDirection_InOut: + case kParameterDirection_Out: + case kParameterDirection_Ref: + isLVal = true; + break; + } + return QualType(paramType, isLVal); + } + + static QualType getParamQualType(Type* paramType) + { + if (auto paramDirType = as<ParamDirectionType>(paramType)) + { + if (as<OutTypeBase>(paramDirType) || as<RefType>(paramDirType)) + return QualType(paramDirType->getValueType(), true); + } + return paramType; + } + bool SemanticsVisitor::TryCheckOverloadCandidateTypes( OverloadResolveContext& context, OverloadCandidate& candidate) { Index argCount = context.getArgCount(); - List<Type*> paramTypes; -// List<DeclRef<ParamDecl>> params; + List<QualType> paramTypes; switch (candidate.flavor) { case OverloadCandidate::Flavor::Func: for (auto param : getParameters(m_astBuilder, candidate.item.declRef.as<CallableDecl>())) { - auto paramType = getType(m_astBuilder, param); - paramTypes.add(paramType); + paramTypes.add(getParamQualType(m_astBuilder, param)); } break; @@ -418,13 +441,7 @@ namespace Slang Count paramCount = funcType->getParamCount(); for (Index i = 0; i < paramCount; ++i) { - auto paramType = funcType->getParamType(i); - - if(auto paramDirectionType = as<ParamDirectionType>(paramType)) - { - paramType = paramDirectionType->getValueType(); - } - + auto paramType = getParamQualType(funcType->getParamType(i)); paramTypes.add(paramType); } } @@ -445,8 +462,8 @@ namespace Slang for (Index ii = 0; ii < argCount; ++ii) { auto& arg = context.getArg(ii); - auto argType = context.getArgType(ii); auto paramType = paramTypes[ii]; + auto argType = QualType(context.getArgType(ii), paramType.isLeftValue); if (!paramType) return false; if (!argType) @@ -1318,7 +1335,7 @@ namespace Slang DeclRef<GenericDecl> genericDeclRef, OverloadResolveContext& context, ArrayView<Val*> knownGenericArgs, - List<Type*> *innerParameterTypes) + List<QualType> *innerParameterTypes) { // We have been asked to infer zero or more arguments to // `genericDeclRef`, in a context where it is being applied @@ -1360,13 +1377,13 @@ namespace Slang if (auto funcDeclRef = as<CallableDecl>(genericDeclRef.getDecl()->inner)) { - List<Type*> paramTypes; + List<QualType> paramTypes; if (!innerParameterTypes) { auto params = getParameters(m_astBuilder, funcDeclRef).toArray(); for (auto param : params) { - paramTypes.add(getType(m_astBuilder, param)); + paramTypes.add(getParamQualType(m_astBuilder, param)); } innerParameterTypes = ¶mTypes; } @@ -1408,11 +1425,12 @@ namespace Slang // // So the question is then whether a mismatch during the // unification step should be taken as an immediate failure... - + auto argType = context.getArgTypeForInference(aa, this); + auto paramType = (*innerParameterTypes)[aa]; TryUnifyTypes( constraints, - context.getArgTypeForInference(aa, this), - (*innerParameterTypes)[aa]); + QualType(argType, paramType.isLeftValue), + paramType); } } else @@ -1679,10 +1697,10 @@ namespace Slang SLANG_ASSERT(diffFuncType); // Extract parameter list from processed type. - List<Type*> paramTypes; + List<QualType> paramTypes; for (Index ii = 0; ii < diffFuncType->getParamCount(); ii++) - paramTypes.add(removeParamDirType(diffFuncType->getParamType(ii))); + paramTypes.add(getParamQualType(diffFuncType->getParamType(ii))); // Try to infer generic arguments, based on the updated context. OverloadResolveContext subContext = context; diff --git a/tests/language-feature/overload-resolution.slang b/tests/language-feature/overload-resolution.slang new file mode 100644 index 000000000..9c135137a --- /dev/null +++ b/tests/language-feature/overload-resolution.slang @@ -0,0 +1,45 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry main +RWStructuredBuffer<float> result; + +[ForceInline] +float myF(inout int a, int b) +{ + return a + b; +} + +[ForceInline] +float myF(inout uint a, uint b) +{ + return a - b; +} + +[ForceInline] +T myGenF<T : __BuiltinIntegerType>(inout T a, T b) +{ + if (__isSignedInt<T>()) + { + return a + b; + } + else + { + return a - b; + } +} +// CHECK: result{{.*}}[0{{U?}}] = 1 +// CHECK: result{{.*}}[1{{U?}}] = 4 +// CHECK: result{{.*}}[2{{U?}}] = 1 +// CHECK: result{{.*}}[3{{U?}}] = 4 +[numthreads(1,1,1)] +void main() +{ + int ic = 1; + uint a = 2; + result[0] = myF(a, ic); + + int b = 3; + uint uc = 1; + result[1] = myF(b, uc); + + result[2] = myGenF(a, ic); + result[3] = myGenF(b, uc); +}
\ No newline at end of file |
