summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2018-05-02 15:11:04 -0400
committerYong He <yonghe@outlook.com>2018-05-02 18:20:32 -0400
commit0399d992e21128a2c4b676e8f5456981ccfa6469 (patch)
tree8ca5c144c438d95406eba5b41be3a3e689beb36a /source/slang
parent60bcc6809f57e12f3705cc65cb325b0983b08899 (diff)
Speedup type checking using cached overload resolution results.
This change adds caches to built-in operator overload resolution and type coersion to avoid running these time-consuming operations every time. - Adds `TypeCheckingCache` type, which is defined in check.cpp, that contains two dictionaries for the cached results of `ResolveInvoke` and `CanCoerce` calls. - Add `destroyTypeCheckingCache` and `getTypeCheckingCache` methods to `Session` class to reuse these cached results over the entire session.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/check.cpp279
-rw-r--r--source/slang/compiler.h5
-rw-r--r--source/slang/slang.cpp2
-rw-r--r--source/slang/syntax.h3
4 files changed, 239 insertions, 50 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 6f9aa8adc..a94812687 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -9,6 +9,174 @@
namespace Slang
{
+ // A flat representation of basic types (scalars, vectors and matrices)
+ // that can be used as lookup key in caches
+ struct BasicTypeKey
+ {
+ union
+ {
+ struct
+ {
+ unsigned char type : 4;
+ unsigned char dim1 : 2;
+ unsigned char dim2 : 2;
+ } data;
+ unsigned char aggVal;
+ };
+ bool fromType(Type* typeIn)
+ {
+ aggVal = 0;
+ if (auto basicType = typeIn->AsBasicType())
+ {
+ data.type = (unsigned char)basicType->baseType;
+ data.dim1 = data.dim2 = 0;
+ }
+ else if (auto vectorType = typeIn->AsVectorType())
+ {
+ if (auto elemCount = vectorType->elementCount.As<ConstantIntVal>())
+ {
+ data.dim1 = elemCount->value - 1;
+ data.type = (unsigned char)vectorType->elementType->AsBasicType()->baseType;
+ data.dim2 = 0;
+ }
+ else
+ return false;
+ }
+ else if (auto matrixType = typeIn->AsMatrixType())
+ {
+ if (auto elemCount1 = dynamic_cast<ConstantIntVal*>(matrixType->getRowCount()))
+ {
+ if (auto elemCount2 = dynamic_cast<ConstantIntVal*>(matrixType->getColumnCount()))
+ {
+ data.type = (unsigned char)matrixType->getElementType()->AsBasicType()->baseType;
+ data.dim1 = elemCount1->value - 1;
+ data.dim2 = elemCount2->value - 1;
+ }
+ }
+ else
+ return false;
+ }
+ else
+ return false;
+ return true;
+ }
+ };
+
+ struct BasicTypeKeyPair
+ {
+ BasicTypeKey type1, type2;
+ bool operator == (BasicTypeKeyPair p)
+ {
+ return type1.aggVal == p.type1.aggVal && type2.aggVal == p.type2.aggVal;
+ }
+ int GetHashCode()
+ {
+ return combineHash(type1.aggVal, type2.aggVal);
+ }
+ };
+
+ struct OverloadCandidate
+ {
+ enum class Flavor
+ {
+ Func,
+ Generic,
+ UnspecializedGeneric,
+ };
+ Flavor flavor;
+
+ enum class Status
+ {
+ GenericArgumentInferenceFailed,
+ Unchecked,
+ ArityChecked,
+ FixityChecked,
+ TypeChecked,
+ DirectionChecked,
+ Appicable,
+ };
+ Status status = Status::Unchecked;
+
+ // Reference to the declaration being applied
+ LookupResultItem item;
+
+ // The type of the result expression if this candidate is selected
+ RefPtr<Type> resultType;
+
+ // A system for tracking constraints introduced on generic parameters
+ // ConstraintSystem constraintSystem;
+
+ // How much conversion cost should be considered for this overload,
+ // when ranking candidates.
+ ConversionCost conversionCostSum = kConversionCost_None;
+
+ // When required, a candidate can store a pre-checked list of
+ // arguments so that we don't have to repeat work across checking
+ // phases. Currently this is only needed for generics.
+ RefPtr<Substitutions> subst;
+ };
+
+ struct OperatorOverloadCacheKey
+ {
+ IROp operatorName;
+ BasicTypeKey args[2];
+ bool operator == (OperatorOverloadCacheKey key)
+ {
+ return operatorName == key.operatorName && args[0].aggVal == key.args[0].aggVal
+ && args[1].aggVal == key.args[1].aggVal;
+ }
+ int GetHashCode()
+ {
+ return ((int)(UInt64)(void*)(operatorName) << 16) ^ (args[0].aggVal << 8) ^ (args[1].aggVal);
+ }
+ bool fromOperatorExpr(OperatorExpr* opExpr)
+ {
+ args[0].aggVal = 0;
+ args[1].aggVal = 0;
+ if (opExpr->Arguments.Count() > 2)
+ return false;
+ if (auto overloadedBase = opExpr->FunctionExpr->As<OverloadedExpr>())
+ {
+ Decl* funcDecl = overloadedBase->lookupResult2.item.declRef.decl;
+ if (auto genDecl = funcDecl->As<GenericDecl>())
+ funcDecl = genDecl->inner.Ptr();
+ if (auto intrinsicOp = funcDecl->FindModifier<IntrinsicOpModifier>())
+ {
+ operatorName = intrinsicOp->op;
+ for (UInt i = 0; i < opExpr->Arguments.Count(); i++)
+ {
+ if (!args[i].fromType(opExpr->Arguments[i]->type.Ptr()))
+ return false;
+ }
+ }
+ else
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+ };
+
+ struct TypeCheckingCache
+ {
+ Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
+ Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;
+ };
+
+ TypeCheckingCache* Session::getTypeCheckingCache()
+ {
+ if (!typeCheckingCache)
+ typeCheckingCache = new TypeCheckingCache();
+ return typeCheckingCache;
+ }
+
+ void Session::destroyTypeCheckingCache()
+ {
+ delete typeCheckingCache;
+ typeCheckingCache = nullptr;
+ }
+
bool IsNumeric(BaseType t)
{
return t == BaseType::Int || t == BaseType::Float || t == BaseType::UInt;
@@ -1233,12 +1401,40 @@ namespace Slang
RefPtr<Type> fromType, // the source type for the conversion
ConversionCost* outCost = 0) // (optional) a place to stuff the conversion cost
{
- return TryCoerceImpl(
+ BasicTypeKey key1, key2;
+ BasicTypeKeyPair cacheKey;
+ bool shouldAddToCache = false;
+ ConversionCost cost;
+ TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache();
+ if (key1.fromType(toType.Ptr()) && key2.fromType(fromType.Ptr()))
+ {
+ cacheKey.type1 = key1;
+ cacheKey.type2 = key2;
+
+ if (typeCheckingCache->conversionCostCache.TryGetValue(cacheKey, cost))
+ {
+ if (outCost)
+ *outCost = cost;
+ return cost != kConversionCost_Impossible;
+ }
+ else
+ shouldAddToCache = true;
+ }
+ bool rs = TryCoerceImpl(
toType,
nullptr,
fromType,
nullptr,
- outCost);
+ &cost);
+ if (outCost)
+ *outCost = cost;
+ if (shouldAddToCache)
+ {
+ if (!rs)
+ cost = kConversionCost_Impossible;
+ typeCheckingCache->conversionCostCache[cacheKey] = cost;
+ }
+ return rs;
}
RefPtr<TypeCastExpr> createImplicitCastExpr()
@@ -4803,51 +4999,7 @@ namespace Slang
return resultSubst;
}
- //
-
- struct OverloadCandidate
- {
- enum class Flavor
- {
- Func,
- Generic,
- UnspecializedGeneric,
- };
- Flavor flavor;
-
- enum class Status
- {
- GenericArgumentInferenceFailed,
- Unchecked,
- ArityChecked,
- FixityChecked,
- TypeChecked,
- DirectionChecked,
- Appicable,
- };
- Status status = Status::Unchecked;
-
- // Reference to the declaration being applied
- LookupResultItem item;
-
- // The type of the result expression if this candidate is selected
- RefPtr<Type> resultType;
-
- // A system for tracking constraints introduced on generic parameters
-// ConstraintSystem constraintSystem;
-
- // How much conversion cost should be considered for this overload,
- // when ranking candidates.
- ConversionCost conversionCostSum = kConversionCost_None;
-
- // When required, a candidate can store a pre-checked list of
- // arguments so that we don't have to repeat work across checking
- // phases. Currently this is only needed for generics.
- RefPtr<Substitutions> subst;
- };
-
-
-
+
// State related to overload resolution for a call
// to an overloaded symbol
struct OverloadResolveContext
@@ -6522,6 +6674,29 @@ namespace Slang
RefPtr<Expr> ResolveInvoke(InvokeExpr * expr)
{
+ OverloadResolveContext context;
+ // check if this is a stdlib operator call, if so we want to use cached results
+ // to speed up compilation
+ bool shouldAddToCache = false;
+ OperatorOverloadCacheKey key;
+ TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache();
+ if (auto opExpr = expr->As<OperatorExpr>())
+ {
+ if (key.fromOperatorExpr(opExpr))
+ {
+ OverloadCandidate candidate;
+ if (typeCheckingCache->resolvedOperatorOverloadCache.TryGetValue(key, candidate))
+ {
+ context.bestCandidateStorage = candidate;
+ context.bestCandidate = &context.bestCandidateStorage;
+ }
+ else
+ {
+ shouldAddToCache = true;
+ }
+ }
+ }
+
// Look at the base expression for the call, and figure out how to invoke it.
auto funcExpr = expr->FunctionExpr;
auto funcExprType = funcExpr->type;
@@ -6540,8 +6715,6 @@ namespace Slang
return CreateErrorExpr(expr);
}
- OverloadResolveContext context;
-
context.originalExpr = expr;
context.funcLoc = funcExpr->loc;
@@ -6561,7 +6734,11 @@ namespace Slang
{
context.baseExpr = funcOverloadExpr2->base;
}
- AddOverloadCandidates(funcExpr, context);
+
+ if (!context.bestCandidate)
+ {
+ AddOverloadCandidates(funcExpr, context);
+ }
if (context.bestCandidates.Count() > 0)
{
@@ -6660,6 +6837,8 @@ namespace Slang
// applicable in the end.
// We will report errors for this one candidate, then, to give
// the user the most help we can.
+ if (shouldAddToCache)
+ typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate;
return CompleteOverloadCandidate(context, *context.bestCandidate);
}
else
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 4cda366f0..e7c40bdc8 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -445,6 +445,7 @@ namespace Slang
char const* text,
CodeGenTarget target);
+ struct TypeCheckingCache;
//
class Session
@@ -535,6 +536,10 @@ namespace Slang
Dictionary<Name*, SyntaxClass<RefObject> > mapNameToSyntaxClass;
+ // cache used by type checking, implemented in check.cpp
+ TypeCheckingCache* typeCheckingCache = nullptr;
+ TypeCheckingCache* getTypeCheckingCache();
+ void destroyTypeCheckingCache();
//
Session();
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 9e740d5f5..3cb580f00 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -823,6 +823,8 @@ Session::~Session()
irBasicBlockType = nullptr;
constExprRate = nullptr;
+ destroyTypeCheckingCache();
+
builtinTypes = decltype(builtinTypes)();
// destroy modules next
loadedModuleCode = decltype(loadedModuleCode)();
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index ebb9d814b..00f7eb95b 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -77,6 +77,9 @@ namespace Slang
// a vector (this will be added to the cost, if any, of converting
// the element type of the vector)
kConversionCost_ScalarToVector = 1,
+
+ // Conversion is impossible
+ kConversionCost_Impossible = 0xFFFFFFFF,
};
// TODO(tfoley): We should ditch this enumeration