diff options
Diffstat (limited to 'source/slang/slang-ast-builder.cpp')
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 89 |
1 files changed, 30 insertions, 59 deletions
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index fa8051171..f8c208ac1 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -224,7 +224,7 @@ NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName) { - auto declRef = getBuiltinDeclRef(magicTypeName, makeConstArrayViewSingle<Val*>(typeParam)); + auto declRef = getBuiltinDeclRef(magicTypeName, typeParam); auto rsType = DeclRefType::create(this, declRef); return rsType; } @@ -263,9 +263,12 @@ PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount) { - ArrayExpressionType* arrayType = create<ArrayExpressionType>(); - arrayType->baseType = elementType; - arrayType->arrayLength = elementCount; + ArrayExpressionType* arrayType = getOrCreateWithDefaultCtor<ArrayExpressionType>(elementType, elementCount); + if (!arrayType->baseType) + { + arrayType->baseType = elementType; + arrayType->arrayLength = elementCount; + } return arrayType; } @@ -273,18 +276,15 @@ VectorExpressionType* ASTBuilder::getVectorType( Type* elementType, IntVal* elementCount) { - auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector")); - - auto vectorTypeDecl = vectorGenericDecl->inner; - - auto substitutions = create<GenericSubstitution>(); - substitutions->genericDecl = vectorGenericDecl; - substitutions->args.add(elementType); - substitutions->args.add(elementCount); - - auto declRef = DeclRef<Decl>(vectorTypeDecl, substitutions); - - return as<VectorExpressionType>(DeclRefType::create(this, declRef)); + auto result = getOrCreate<VectorExpressionType>(elementType, elementCount); + if (!result->declRef.decl) + { + auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector")); + auto vectorTypeDecl = vectorGenericDecl->inner; + auto substitutions = getOrCreate<GenericSubstitution>(vectorGenericDecl, elementType, elementCount); + result->declRef = DeclRef<Decl>(vectorTypeDecl, substitutions); + } + return result; } DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness) @@ -293,10 +293,7 @@ DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witne auto typeDecl = genericDecl->inner; - auto substitutions = create<GenericSubstitution>(); - substitutions->genericDecl = genericDecl; - substitutions->args.add(valueType); - substitutions->args.add(conformanceWitness); + auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, valueType, conformanceWitness); auto declRef = DeclRef<Decl>(typeDecl, substitutions); auto rsType = DeclRefType::create(this, declRef); @@ -311,34 +308,29 @@ DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface() return declRef; } -DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs) +DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { DeclRef<Decl> declRef; declRef.decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); if (auto genericDecl = as<GenericDecl>(declRef.decl)) { - if (genericArgs.getCount()) + if (genericArg) { - auto substitutions = create<GenericSubstitution>(); - substitutions->genericDecl = genericDecl; - for (auto arg : genericArgs) - substitutions->args.add(arg); + auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, genericArg); declRef.substitutions = substitutions; } declRef.decl = genericDecl->inner; } else { - SLANG_ASSERT(genericArgs.getCount() == 0); + SLANG_ASSERT(!genericArg); } return declRef; } Type* ASTBuilder::getAndType(Type* left, Type* right) { - auto type = create<AndType>(); - type->left = left; - type->right = right; + auto type = getOrCreate<AndType>(left, right); return type; } @@ -350,46 +342,26 @@ Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* m return type; } -NodeBase* ASTBuilder::_getOrCreateImpl(NodeDesc const& desc, NodeCreateFunc createFunc, void* createFuncUserData) -{ - if(auto found = m_cachedNodes.TryGetValue(desc)) - return *found; - - auto node = createFunc(this, desc, createFuncUserData); - - auto operandCount = desc.operandCount; - NodeBase** operandsCopy = m_arena.allocateAndZeroArray<NodeBase*>(desc.operandCount); - for(Index i = 0; i < operandCount; ++i) - operandsCopy[i] = desc.operands[i]; - - NodeDesc descCopy = desc; - descCopy.operands = operandsCopy; - m_cachedNodes.Add(descCopy, node); - - return node; -} - - Val* ASTBuilder::getUNormModifierVal() { - return _getOrCreate<UNormModifierVal>(); + return getOrCreate<UNormModifierVal>(); } Val* ASTBuilder::getSNormModifierVal() { - return _getOrCreate<SNormModifierVal>(); + return getOrCreate<SNormModifierVal>(); } TypeType* ASTBuilder::getTypeType(Type* type) { - return create<TypeType>(type); + return getOrCreate<TypeType>(type); } bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const { if(type != that.type) return false; - if(operandCount != that.operandCount) return false; - for(Index i = 0; i < operandCount; ++i) + if(operands.getCount() != that.operands.getCount()) return false; + for(Index i = 0; i < operands.getCount(); ++i) { // Note: we are comparing the operands directly for identity // (pointer equality) rather than doing the `Val`-level @@ -399,7 +371,7 @@ bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const // via a `NodeDesc` *should* all be going through the // deduplication path anyway, as should their operands. // - if(operands[i] != that.operands[i]) return false; + if(operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; } return true; } @@ -407,15 +379,14 @@ HashCode ASTBuilder::NodeDesc::getHashCode() const { Hasher hasher; hasher.hashValue(Int(type)); - hasher.hashValue(operandCount); - for(Index i = 0; i < operandCount; ++i) + for(Index i = 0; i < operands.getCount(); ++i) { // Note: we are hashing the raw pointer value rather // than the content of the value node. This is done // to match the semantics implemented for `==` on // `NodeDesc`. // - hasher.hashValue((void*) operands[i]); + hasher.hashValue(operands[i].values.nodeOperand); } return hasher.getResult(); } |
