summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-builder.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
-rw-r--r--source/slang/slang-ast-builder.cpp89
1 files changed, 30 insertions, 59 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index fa8051171..f8c208ac1 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -224,7 +224,7 @@ NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType)
Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName)
{
- auto declRef = getBuiltinDeclRef(magicTypeName, makeConstArrayViewSingle<Val*>(typeParam));
+ auto declRef = getBuiltinDeclRef(magicTypeName, typeParam);
auto rsType = DeclRefType::create(this, declRef);
return rsType;
}
@@ -263,9 +263,12 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName)
ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount)
{
- ArrayExpressionType* arrayType = create<ArrayExpressionType>();
- arrayType->baseType = elementType;
- arrayType->arrayLength = elementCount;
+ ArrayExpressionType* arrayType = getOrCreateWithDefaultCtor<ArrayExpressionType>(elementType, elementCount);
+ if (!arrayType->baseType)
+ {
+ arrayType->baseType = elementType;
+ arrayType->arrayLength = elementCount;
+ }
return arrayType;
}
@@ -273,18 +276,15 @@ VectorExpressionType* ASTBuilder::getVectorType(
Type* elementType,
IntVal* elementCount)
{
- auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector"));
-
- auto vectorTypeDecl = vectorGenericDecl->inner;
-
- auto substitutions = create<GenericSubstitution>();
- substitutions->genericDecl = vectorGenericDecl;
- substitutions->args.add(elementType);
- substitutions->args.add(elementCount);
-
- auto declRef = DeclRef<Decl>(vectorTypeDecl, substitutions);
-
- return as<VectorExpressionType>(DeclRefType::create(this, declRef));
+ auto result = getOrCreate<VectorExpressionType>(elementType, elementCount);
+ if (!result->declRef.decl)
+ {
+ auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector"));
+ auto vectorTypeDecl = vectorGenericDecl->inner;
+ auto substitutions = getOrCreate<GenericSubstitution>(vectorGenericDecl, elementType, elementCount);
+ result->declRef = DeclRef<Decl>(vectorTypeDecl, substitutions);
+ }
+ return result;
}
DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness)
@@ -293,10 +293,7 @@ DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witne
auto typeDecl = genericDecl->inner;
- auto substitutions = create<GenericSubstitution>();
- substitutions->genericDecl = genericDecl;
- substitutions->args.add(valueType);
- substitutions->args.add(conformanceWitness);
+ auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, valueType, conformanceWitness);
auto declRef = DeclRef<Decl>(typeDecl, substitutions);
auto rsType = DeclRefType::create(this, declRef);
@@ -311,34 +308,29 @@ DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface()
return declRef;
}
-DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs)
+DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg)
{
DeclRef<Decl> declRef;
declRef.decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
if (auto genericDecl = as<GenericDecl>(declRef.decl))
{
- if (genericArgs.getCount())
+ if (genericArg)
{
- auto substitutions = create<GenericSubstitution>();
- substitutions->genericDecl = genericDecl;
- for (auto arg : genericArgs)
- substitutions->args.add(arg);
+ auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, genericArg);
declRef.substitutions = substitutions;
}
declRef.decl = genericDecl->inner;
}
else
{
- SLANG_ASSERT(genericArgs.getCount() == 0);
+ SLANG_ASSERT(!genericArg);
}
return declRef;
}
Type* ASTBuilder::getAndType(Type* left, Type* right)
{
- auto type = create<AndType>();
- type->left = left;
- type->right = right;
+ auto type = getOrCreate<AndType>(left, right);
return type;
}
@@ -350,46 +342,26 @@ Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* m
return type;
}
-NodeBase* ASTBuilder::_getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc, void* createFuncUserData)
-{
- if(auto found = m_cachedNodes.TryGetValue(desc))
- return *found;
-
- auto node = createFunc(this, desc, createFuncUserData);
-
- auto operandCount = desc.operandCount;
- NodeBase** operandsCopy = m_arena.allocateAndZeroArray<NodeBase*>(desc.operandCount);
- for(Index i = 0; i < operandCount; ++i)
- operandsCopy[i] = desc.operands[i];
-
- NodeDesc descCopy = desc;
- descCopy.operands = operandsCopy;
- m_cachedNodes.Add(descCopy, node);
-
- return node;
-}
-
-
Val* ASTBuilder::getUNormModifierVal()
{
- return _getOrCreate<UNormModifierVal>();
+ return getOrCreate<UNormModifierVal>();
}
Val* ASTBuilder::getSNormModifierVal()
{
- return _getOrCreate<SNormModifierVal>();
+ return getOrCreate<SNormModifierVal>();
}
TypeType* ASTBuilder::getTypeType(Type* type)
{
- return create<TypeType>(type);
+ return getOrCreate<TypeType>(type);
}
bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const
{
if(type != that.type) return false;
- if(operandCount != that.operandCount) return false;
- for(Index i = 0; i < operandCount; ++i)
+ if(operands.getCount() != that.operands.getCount()) return false;
+ for(Index i = 0; i < operands.getCount(); ++i)
{
// Note: we are comparing the operands directly for identity
// (pointer equality) rather than doing the `Val`-level
@@ -399,7 +371,7 @@ bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const
// via a `NodeDesc` *should* all be going through the
// deduplication path anyway, as should their operands.
//
- if(operands[i] != that.operands[i]) return false;
+ if(operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false;
}
return true;
}
@@ -407,15 +379,14 @@ HashCode ASTBuilder::NodeDesc::getHashCode() const
{
Hasher hasher;
hasher.hashValue(Int(type));
- hasher.hashValue(operandCount);
- for(Index i = 0; i < operandCount; ++i)
+ for(Index i = 0; i < operands.getCount(); ++i)
{
// Note: we are hashing the raw pointer value rather
// than the content of the value node. This is done
// to match the semantics implemented for `==` on
// `NodeDesc`.
//
- hasher.hashValue((void*) operands[i]);
+ hasher.hashValue(operands[i].values.nodeOperand);
}
return hasher.getResult();
}