diff options
| -rw-r--r-- | source/slang/check.cpp | 279 | ||||
| -rw-r--r-- | source/slang/compiler.h | 5 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 2 | ||||
| -rw-r--r-- | source/slang/syntax.h | 3 |
4 files changed, 239 insertions, 50 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 6f9aa8adc..a94812687 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -9,6 +9,174 @@ namespace Slang { + // A flat representation of basic types (scalars, vectors and matrices) + // that can be used as lookup key in caches + struct BasicTypeKey + { + union + { + struct + { + unsigned char type : 4; + unsigned char dim1 : 2; + unsigned char dim2 : 2; + } data; + unsigned char aggVal; + }; + bool fromType(Type* typeIn) + { + aggVal = 0; + if (auto basicType = typeIn->AsBasicType()) + { + data.type = (unsigned char)basicType->baseType; + data.dim1 = data.dim2 = 0; + } + else if (auto vectorType = typeIn->AsVectorType()) + { + if (auto elemCount = vectorType->elementCount.As<ConstantIntVal>()) + { + data.dim1 = elemCount->value - 1; + data.type = (unsigned char)vectorType->elementType->AsBasicType()->baseType; + data.dim2 = 0; + } + else + return false; + } + else if (auto matrixType = typeIn->AsMatrixType()) + { + if (auto elemCount1 = dynamic_cast<ConstantIntVal*>(matrixType->getRowCount())) + { + if (auto elemCount2 = dynamic_cast<ConstantIntVal*>(matrixType->getColumnCount())) + { + data.type = (unsigned char)matrixType->getElementType()->AsBasicType()->baseType; + data.dim1 = elemCount1->value - 1; + data.dim2 = elemCount2->value - 1; + } + } + else + return false; + } + else + return false; + return true; + } + }; + + struct BasicTypeKeyPair + { + BasicTypeKey type1, type2; + bool operator == (BasicTypeKeyPair p) + { + return type1.aggVal == p.type1.aggVal && type2.aggVal == p.type2.aggVal; + } + int GetHashCode() + { + return combineHash(type1.aggVal, type2.aggVal); + } + }; + + struct OverloadCandidate + { + enum class Flavor + { + Func, + Generic, + UnspecializedGeneric, + }; + Flavor flavor; + + enum class Status + { + GenericArgumentInferenceFailed, + Unchecked, + ArityChecked, + FixityChecked, + TypeChecked, + DirectionChecked, + Appicable, + }; + Status status = Status::Unchecked; + + // Reference to the declaration being applied + LookupResultItem item; + + // The type of the result expression if this candidate is selected + RefPtr<Type> resultType; + + // A system for tracking constraints introduced on generic parameters + // ConstraintSystem constraintSystem; + + // How much conversion cost should be considered for this overload, + // when ranking candidates. + ConversionCost conversionCostSum = kConversionCost_None; + + // When required, a candidate can store a pre-checked list of + // arguments so that we don't have to repeat work across checking + // phases. Currently this is only needed for generics. + RefPtr<Substitutions> subst; + }; + + struct OperatorOverloadCacheKey + { + IROp operatorName; + BasicTypeKey args[2]; + bool operator == (OperatorOverloadCacheKey key) + { + return operatorName == key.operatorName && args[0].aggVal == key.args[0].aggVal + && args[1].aggVal == key.args[1].aggVal; + } + int GetHashCode() + { + return ((int)(UInt64)(void*)(operatorName) << 16) ^ (args[0].aggVal << 8) ^ (args[1].aggVal); + } + bool fromOperatorExpr(OperatorExpr* opExpr) + { + args[0].aggVal = 0; + args[1].aggVal = 0; + if (opExpr->Arguments.Count() > 2) + return false; + if (auto overloadedBase = opExpr->FunctionExpr->As<OverloadedExpr>()) + { + Decl* funcDecl = overloadedBase->lookupResult2.item.declRef.decl; + if (auto genDecl = funcDecl->As<GenericDecl>()) + funcDecl = genDecl->inner.Ptr(); + if (auto intrinsicOp = funcDecl->FindModifier<IntrinsicOpModifier>()) + { + operatorName = intrinsicOp->op; + for (UInt i = 0; i < opExpr->Arguments.Count(); i++) + { + if (!args[i].fromType(opExpr->Arguments[i]->type.Ptr())) + return false; + } + } + else + { + return false; + } + } + return true; + } + }; + + struct TypeCheckingCache + { + Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache; + Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache; + }; + + TypeCheckingCache* Session::getTypeCheckingCache() + { + if (!typeCheckingCache) + typeCheckingCache = new TypeCheckingCache(); + return typeCheckingCache; + } + + void Session::destroyTypeCheckingCache() + { + delete typeCheckingCache; + typeCheckingCache = nullptr; + } + bool IsNumeric(BaseType t) { return t == BaseType::Int || t == BaseType::Float || t == BaseType::UInt; @@ -1233,12 +1401,40 @@ namespace Slang RefPtr<Type> fromType, // the source type for the conversion ConversionCost* outCost = 0) // (optional) a place to stuff the conversion cost { - return TryCoerceImpl( + BasicTypeKey key1, key2; + BasicTypeKeyPair cacheKey; + bool shouldAddToCache = false; + ConversionCost cost; + TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache(); + if (key1.fromType(toType.Ptr()) && key2.fromType(fromType.Ptr())) + { + cacheKey.type1 = key1; + cacheKey.type2 = key2; + + if (typeCheckingCache->conversionCostCache.TryGetValue(cacheKey, cost)) + { + if (outCost) + *outCost = cost; + return cost != kConversionCost_Impossible; + } + else + shouldAddToCache = true; + } + bool rs = TryCoerceImpl( toType, nullptr, fromType, nullptr, - outCost); + &cost); + if (outCost) + *outCost = cost; + if (shouldAddToCache) + { + if (!rs) + cost = kConversionCost_Impossible; + typeCheckingCache->conversionCostCache[cacheKey] = cost; + } + return rs; } RefPtr<TypeCastExpr> createImplicitCastExpr() @@ -4803,51 +4999,7 @@ namespace Slang return resultSubst; } - // - - struct OverloadCandidate - { - enum class Flavor - { - Func, - Generic, - UnspecializedGeneric, - }; - Flavor flavor; - - enum class Status - { - GenericArgumentInferenceFailed, - Unchecked, - ArityChecked, - FixityChecked, - TypeChecked, - DirectionChecked, - Appicable, - }; - Status status = Status::Unchecked; - - // Reference to the declaration being applied - LookupResultItem item; - - // The type of the result expression if this candidate is selected - RefPtr<Type> resultType; - - // A system for tracking constraints introduced on generic parameters -// ConstraintSystem constraintSystem; - - // How much conversion cost should be considered for this overload, - // when ranking candidates. - ConversionCost conversionCostSum = kConversionCost_None; - - // When required, a candidate can store a pre-checked list of - // arguments so that we don't have to repeat work across checking - // phases. Currently this is only needed for generics. - RefPtr<Substitutions> subst; - }; - - - + // State related to overload resolution for a call // to an overloaded symbol struct OverloadResolveContext @@ -6522,6 +6674,29 @@ namespace Slang RefPtr<Expr> ResolveInvoke(InvokeExpr * expr) { + OverloadResolveContext context; + // check if this is a stdlib operator call, if so we want to use cached results + // to speed up compilation + bool shouldAddToCache = false; + OperatorOverloadCacheKey key; + TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache(); + if (auto opExpr = expr->As<OperatorExpr>()) + { + if (key.fromOperatorExpr(opExpr)) + { + OverloadCandidate candidate; + if (typeCheckingCache->resolvedOperatorOverloadCache.TryGetValue(key, candidate)) + { + context.bestCandidateStorage = candidate; + context.bestCandidate = &context.bestCandidateStorage; + } + else + { + shouldAddToCache = true; + } + } + } + // Look at the base expression for the call, and figure out how to invoke it. auto funcExpr = expr->FunctionExpr; auto funcExprType = funcExpr->type; @@ -6540,8 +6715,6 @@ namespace Slang return CreateErrorExpr(expr); } - OverloadResolveContext context; - context.originalExpr = expr; context.funcLoc = funcExpr->loc; @@ -6561,7 +6734,11 @@ namespace Slang { context.baseExpr = funcOverloadExpr2->base; } - AddOverloadCandidates(funcExpr, context); + + if (!context.bestCandidate) + { + AddOverloadCandidates(funcExpr, context); + } if (context.bestCandidates.Count() > 0) { @@ -6660,6 +6837,8 @@ namespace Slang // applicable in the end. // We will report errors for this one candidate, then, to give // the user the most help we can. + if (shouldAddToCache) + typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate; return CompleteOverloadCandidate(context, *context.bestCandidate); } else diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 4cda366f0..e7c40bdc8 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -445,6 +445,7 @@ namespace Slang char const* text, CodeGenTarget target); + struct TypeCheckingCache; // class Session @@ -535,6 +536,10 @@ namespace Slang Dictionary<Name*, SyntaxClass<RefObject> > mapNameToSyntaxClass; + // cache used by type checking, implemented in check.cpp + TypeCheckingCache* typeCheckingCache = nullptr; + TypeCheckingCache* getTypeCheckingCache(); + void destroyTypeCheckingCache(); // Session(); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 9e740d5f5..3cb580f00 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -823,6 +823,8 @@ Session::~Session() irBasicBlockType = nullptr; constExprRate = nullptr; + destroyTypeCheckingCache(); + builtinTypes = decltype(builtinTypes)(); // destroy modules next loadedModuleCode = decltype(loadedModuleCode)(); diff --git a/source/slang/syntax.h b/source/slang/syntax.h index ebb9d814b..00f7eb95b 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -77,6 +77,9 @@ namespace Slang // a vector (this will be added to the cost, if any, of converting // the element type of the vector) kConversionCost_ScalarToVector = 1, + + // Conversion is impossible + kConversionCost_Impossible = 0xFFFFFFFF, }; // TODO(tfoley): We should ditch this enumeration |
