// slang-ast-builder.cpp #include "slang-ast-builder.h" #include #include "slang-compiler.h" namespace Slang { // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SharedASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SharedASTBuilder::SharedASTBuilder() { } void SharedASTBuilder::init(Session* session) { m_namePool = session->getNamePool(); // Save the associated session m_session = session; // We just want as a place to store allocations of shared types { RefPtr astBuilder(new ASTBuilder); astBuilder->m_sharedASTBuilder = this; m_astBuilder = astBuilder.detach(); } // Clear the built in types memset(m_builtinTypes, 0, sizeof(m_builtinTypes)); // Create common shared types m_errorType = m_astBuilder->create(); m_bottomType = m_astBuilder->create(); m_initializerListType = m_astBuilder->create(); m_overloadedType = m_astBuilder->create(); // We can just iterate over the class pointers. // NOTE! That this adds the names of the abstract classes too(!) for (Index i = 0; i < Index(ASTNodeType::CountOf); ++i) { const ReflectClassInfo* info = ASTClassInfo::getInfo(ASTNodeType(i)); if (info) { m_sliceToTypeMap.Add(UnownedStringSlice(info->m_name), info); Name* name = m_namePool->getName(String(info->m_name)); m_nameToTypeMap.Add(name, info); } } } const ReflectClassInfo* SharedASTBuilder::findClassInfo(const UnownedStringSlice& slice) { const ReflectClassInfo* typeInfo; return m_sliceToTypeMap.TryGetValue(slice, typeInfo) ? typeInfo : nullptr; } SyntaxClass SharedASTBuilder::findSyntaxClass(const UnownedStringSlice& slice) { const ReflectClassInfo* typeInfo; if (m_sliceToTypeMap.TryGetValue(slice, typeInfo)) { return SyntaxClass(typeInfo); } return SyntaxClass(); } const ReflectClassInfo* SharedASTBuilder::findClassInfo(Name* name) { const ReflectClassInfo* typeInfo; return m_nameToTypeMap.TryGetValue(name, typeInfo) ? typeInfo : nullptr; } SyntaxClass SharedASTBuilder::findSyntaxClass(Name* name) { const ReflectClassInfo* typeInfo; if (m_nameToTypeMap.TryGetValue(name, typeInfo)) { return SyntaxClass(typeInfo); } return SyntaxClass(); } Type* SharedASTBuilder::getStringType() { if (!m_stringType) { auto stringTypeDecl = findMagicDecl("StringType"); m_stringType = DeclRefType::create(m_astBuilder, makeDeclRef(stringTypeDecl)); } return m_stringType; } Type* SharedASTBuilder::getNativeStringType() { if (!m_nativeStringType) { auto nativeStringTypeDecl = findMagicDecl("NativeStringType"); m_nativeStringType = DeclRefType::create(m_astBuilder, makeDeclRef(nativeStringTypeDecl)); } return m_nativeStringType; } Type* SharedASTBuilder::getEnumTypeType() { if (!m_enumTypeType) { auto enumTypeTypeDecl = findMagicDecl("EnumTypeType"); m_enumTypeType = DeclRefType::create(m_astBuilder, makeDeclRef(enumTypeTypeDecl)); } return m_enumTypeType; } Type* SharedASTBuilder::getDynamicType() { if (!m_dynamicType) { auto dynamicTypeDecl = findMagicDecl("DynamicType"); m_dynamicType = DeclRefType::create(m_astBuilder, makeDeclRef(dynamicTypeDecl)); } return m_dynamicType; } Type* SharedASTBuilder::getNullPtrType() { if (!m_nullPtrType) { auto nullPtrTypeDecl = findMagicDecl("NullPtrType"); m_nullPtrType = DeclRefType::create(m_astBuilder, makeDeclRef(nullPtrTypeDecl)); } return m_nullPtrType; } Type* SharedASTBuilder::getNoneType() { if (!m_noneType) { auto noneTypeDecl = findMagicDecl("NoneType"); m_noneType = DeclRefType::create(m_astBuilder, makeDeclRef(noneTypeDecl)); } return m_noneType; } Type* SharedASTBuilder::getDiffInterfaceType() { if (!m_diffInterfaceType) { auto decl = findMagicDecl("DifferentiableType"); m_diffInterfaceType = DeclRefType::create(m_astBuilder, makeDeclRef(decl)); } return m_diffInterfaceType; } SharedASTBuilder::~SharedASTBuilder() { // Release built in types.. for (Index i = 0; i < SLANG_COUNT_OF(m_builtinTypes); ++i) { m_builtinTypes[i] = nullptr; } if (m_astBuilder) { m_astBuilder->releaseReference(); } } void SharedASTBuilder::registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modifier) { auto type = DeclRefType::create(m_astBuilder, DeclRef(decl, nullptr)); m_builtinTypes[Index(modifier->tag)] = type; } void SharedASTBuilder::registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier) { m_builtinRequirementDecls[modifier->kind] = decl; } void SharedASTBuilder::registerMagicDecl(Decl* decl, MagicTypeModifier* modifier) { // In some cases the modifier will have been applied to the // "inner" declaration of a `GenericDecl`, but what we // actually want to register is the generic itself. // auto declToRegister = decl; if (auto genericDecl = as(decl->parentDecl)) declToRegister = genericDecl; m_magicDecls[modifier->magicName] = declToRegister; } Decl* SharedASTBuilder::findMagicDecl(const String& name) { return m_magicDecls[name].GetValue(); } Decl* SharedASTBuilder::tryFindMagicDecl(const String& name) { if (m_magicDecls.ContainsKey(name)) { return m_magicDecls[name].GetValue(); } else { return nullptr; } } // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder !!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTBuilder::ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name): m_sharedASTBuilder(sharedASTBuilder), m_name(name), m_id(sharedASTBuilder->m_id++), m_arena(2048) { SLANG_ASSERT(sharedASTBuilder); } ASTBuilder::ASTBuilder(): m_sharedASTBuilder(nullptr), m_id(-1), m_arena(2048) { m_name = "SharedASTBuilder::m_astBuilder"; } ASTBuilder::~ASTBuilder() { for (NodeBase* node : m_dtorNodes) { const ReflectClassInfo* info = ASTClassInfo::getInfo(node->astNodeType); SLANG_ASSERT(info->m_destructorFunc); info->m_destructorFunc(node); } } NodeBase* ASTBuilder::createByNodeType(ASTNodeType nodeType) { const ReflectClassInfo* info = ASTClassInfo::getInfo(nodeType); auto createFunc = info->m_createFunc; SLANG_ASSERT(createFunc); if (!createFunc) { return nullptr; } return (NodeBase*)createFunc(this); } Type* ASTBuilder::getSpecializedBuiltinType(Type* typeParam, char const* magicTypeName) { auto declRef = getBuiltinDeclRef(magicTypeName, typeParam); auto rsType = DeclRefType::create(this, declRef); return rsType; } PtrType* ASTBuilder::getPtrType(Type* valueType) { return dynamicCast(getPtrType(valueType, "PtrType")); } // Construct the type `Out` OutType* ASTBuilder::getOutType(Type* valueType) { return dynamicCast(getPtrType(valueType, "OutType")); } InOutType* ASTBuilder::getInOutType(Type* valueType) { return dynamicCast(getPtrType(valueType, "InOutType")); } RefType* ASTBuilder::getRefType(Type* valueType) { return dynamicCast(getPtrType(valueType, "RefType")); } OptionalType* ASTBuilder::getOptionalType(Type* valueType) { auto rsType = getSpecializedBuiltinType(valueType, "OptionalType"); return as(rsType); } PtrTypeBase* ASTBuilder::getPtrType(Type* valueType, char const* ptrTypeName) { return as(getSpecializedBuiltinType(valueType, ptrTypeName)); } ArrayExpressionType* ASTBuilder::getArrayType(Type* elementType, IntVal* elementCount) { if (!elementCount) elementCount = getIntVal(getIntType(), kUnsizedArrayMagicLength); auto result = getOrCreate(elementType, elementCount); if (!result->declRef.decl) { auto arrayGenericDecl = as(m_sharedASTBuilder->findMagicDecl("ArrayType")); auto arrayTypeDecl = arrayGenericDecl->inner; auto substitutions = getOrCreate(arrayGenericDecl, elementType, elementCount); result->declRef = DeclRef(arrayTypeDecl, substitutions); } return result; } VectorExpressionType* ASTBuilder::getVectorType( Type* elementType, IntVal* elementCount) { auto result = getOrCreate(elementType, elementCount); if (!result->declRef.decl) { auto vectorGenericDecl = as(m_sharedASTBuilder->findMagicDecl("Vector")); auto vectorTypeDecl = vectorGenericDecl->inner; auto substitutions = getOrCreate(vectorGenericDecl, elementType, elementCount); result->declRef = DeclRef(vectorTypeDecl, substitutions); } return result; } DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, Witness* primalIsDifferentialWitness) { auto genericDecl = dynamicCast(m_sharedASTBuilder->findMagicDecl("DifferentialPairType")); auto typeDecl = genericDecl->inner; auto substitutions = getOrCreate( genericDecl, valueType, primalIsDifferentialWitness); auto declRef = DeclRef(typeDecl, substitutions); auto rsType = DeclRefType::create(this, declRef); return as(rsType); } DeclRef ASTBuilder::getDifferentiableInterface() { DeclRef declRef; declRef.decl = dynamicCast(m_sharedASTBuilder->findMagicDecl("DifferentiableType")); return declRef; } bool ASTBuilder::isDifferentiableInterfaceAvailable() { return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr); } MeshOutputType* ASTBuilder::getMeshOutputTypeFromModifier( HLSLMeshShaderOutputModifier* modifier, Type* elementType, IntVal* maxElementCount) { SLANG_ASSERT(modifier); SLANG_ASSERT(elementType); SLANG_ASSERT(maxElementCount); const char* declName = as(modifier) ? "VerticesType" : as(modifier) ? "IndicesType" : as(modifier) ? "PrimitivesType" : (SLANG_UNEXPECTED("Unhandled mesh output modifier"), nullptr); auto genericDecl = dynamicCast(m_sharedASTBuilder->findMagicDecl(declName)); auto typeDecl = genericDecl->inner; auto substitutions = getOrCreate( genericDecl, elementType, maxElementCount); auto declRef = DeclRef(typeDecl, substitutions); auto rsType = DeclRefType::create(this, declRef); return as(rsType); } DeclRef ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { DeclRef declRef; declRef.decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); if (auto genericDecl = as(declRef.decl)) { if (genericArg) { auto substitutions = getOrCreate(genericDecl, genericArg); declRef.substitutions = substitutions; } declRef.decl = genericDecl->inner; } else { SLANG_ASSERT(!genericArg); } return declRef; } Type* ASTBuilder::getAndType(Type* left, Type* right) { auto type = getOrCreate(left, right); return type; } Type* ASTBuilder::getModifiedType(Type* base, Count modifierCount, Val* const* modifiers) { auto type = create(); type->base = base; type->modifiers.addRange(modifiers, modifierCount); return type; } Val* ASTBuilder::getUNormModifierVal() { return getOrCreate(); } Val* ASTBuilder::getSNormModifierVal() { return getOrCreate(); } Val* ASTBuilder::getNoDiffModifierVal() { return getOrCreate(); } TypeType* ASTBuilder::getTypeType(Type* type) { return getOrCreate(type); } bool ASTBuilder::NodeDesc::operator==(NodeDesc const& that) const { if(type != that.type) return false; if(operands.getCount() != that.operands.getCount()) return false; for(Index i = 0; i < operands.getCount(); ++i) { // Note: we are comparing the operands directly for identity // (pointer equality) rather than doing the `Val`-level // equality check. // // The rationale here is that nodes that will be created // via a `NodeDesc` *should* all be going through the // deduplication path anyway, as should their operands. // if(operands[i].values.nodeOperand != that.operands[i].values.nodeOperand) return false; } return true; } HashCode ASTBuilder::NodeDesc::getHashCode() const { Hasher hasher; hasher.hashValue(Int(type)); for(Index i = 0; i < operands.getCount(); ++i) { // Note: we are hashing the raw pointer value rather // than the content of the value node. This is done // to match the semantics implemented for `==` on // `NodeDesc`. // hasher.hashValue(operands[i].values.nodeOperand); } return hasher.getResult(); } } // namespace Slang