#include "syntax.h" #include "syntax-visitors.h" #include #include namespace Slang { namespace Compiler { // BasicExpressionType bool BasicExpressionType::EqualsImpl(ExpressionType * type) { auto basicType = dynamic_cast(type); if (basicType == nullptr) return false; return basicType->BaseType == BaseType; } ExpressionType* BasicExpressionType::CreateCanonicalType() { // A basic type is already canonical, in our setup return this; } CoreLib::Basic::String BasicExpressionType::ToString() { CoreLib::Basic::StringBuilder res; switch (BaseType) { case Compiler::BaseType::Int: res.Append("int"); break; case Compiler::BaseType::UInt: res.Append("uint"); break; case Compiler::BaseType::UInt64: res.Append("uint64_t"); break; case Compiler::BaseType::Bool: res.Append("bool"); break; case Compiler::BaseType::Float: res.Append("float"); break; case Compiler::BaseType::Void: res.Append("void"); break; default: break; } return res.ProduceString(); } RefPtr ProgramSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitProgram(this); } RefPtr FunctionSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitFunction(this); } // RefPtr ScopeDecl::Accept(SyntaxVisitor * visitor) { return visitor->VisitScopeDecl(this); } // RefPtr BlockStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitBlockStatement(this); } RefPtr BreakStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitBreakStatement(this); } RefPtr ContinueStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitContinueStatement(this); } RefPtr DoWhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitDoWhileStatement(this); } RefPtr EmptyStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitEmptyStatement(this); } RefPtr ForStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitForStatement(this); } RefPtr IfStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitIfStatement(this); } RefPtr ReturnStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitReturnStatement(this); } RefPtr VarDeclrStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitVarDeclrStatement(this); } RefPtr Variable::Accept(SyntaxVisitor * visitor) { return visitor->VisitDeclrVariable(this); } RefPtr WhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitWhileStatement(this); } RefPtr ExpressionStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitExpressionStatement(this); } RefPtr ConstantExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitConstantExpression(this); } RefPtr IndexExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitIndexExpression(this); } RefPtr MemberExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitMemberExpression(this); } // SwizzleExpr RefPtr SwizzleExpr::Accept(SyntaxVisitor * visitor) { return visitor->VisitSwizzleExpression(this); } // DerefExpr RefPtr DerefExpr::Accept(SyntaxVisitor * /*visitor*/) { // throw "unimplemented"; return this; } // RefPtr InvokeExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitInvokeExpression(this); } RefPtr TypeCastExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitTypeCastExpression(this); } RefPtr VarExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitVarExpression(this); } // OverloadedExpr RefPtr OverloadedExpr::Accept(SyntaxVisitor * /*visitor*/) { // throw "unimplemented"; return this; } // RefPtr ParameterSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitParameter(this); } // ImportDecl RefPtr ImportDecl::Accept(SyntaxVisitor * visitor) { visitor->visitImportDecl(this); return this; } // RefPtr StructField::Accept(SyntaxVisitor * visitor) { return visitor->VisitStructField(this); } RefPtr StructSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitStruct(this); } RefPtr ClassSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitClass(this); } RefPtr TypeDefDecl::Accept(SyntaxVisitor * visitor) { return visitor->VisitTypeDefDecl(this); } RefPtr DiscardStatementSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitDiscardStatement(this); } // BasicExpressionType BasicExpressionType* BasicExpressionType::GetScalarType() { return this; } // bool ExpressionType::Equals(ExpressionType * type) { return GetCanonicalType()->EqualsImpl(type->GetCanonicalType()); } bool ExpressionType::Equals(RefPtr type) { return Equals(type.Ptr()); } bool ExpressionType::EqualsVal(Val* val) { if (auto type = dynamic_cast(val)) return const_cast(this)->Equals(type); return false; } NamedExpressionType* ExpressionType::AsNamedType() { return dynamic_cast(this); } RefPtr ExpressionType::SubstituteImpl(Substitutions* subst, int* ioDiff) { int diff = 0; auto canSubst = GetCanonicalType()->SubstituteImpl(subst, &diff); // If nothing changed, then don't drop any sugar that is applied if (!diff) return this; // If the canonical type changed, then we return a canonical type, // rather than try to re-construct any amount of sugar (*ioDiff)++; return canSubst; } ExpressionType* ExpressionType::GetCanonicalType() { if (!this) return nullptr; ExpressionType* et = const_cast(this); if (!et->canonicalType) { // TODO(tfoley): worry about thread safety here? et->canonicalType = et->CreateCanonicalType(); assert(et->canonicalType); } return et->canonicalType; } bool ExpressionType::IsTextureOrSampler() { return IsTexture() || IsSampler(); } bool ExpressionType::IsStruct() { auto declRefType = AsDeclRefType(); if (!declRefType) return false; auto structDeclRef = declRefType->declRef.As(); if (!structDeclRef) return false; return true; } bool ExpressionType::IsClass() { auto declRefType = AsDeclRefType(); if (!declRefType) return false; auto classDeclRef = declRefType->declRef.As(); if (!classDeclRef) return false; return true; } #if 0 RefPtr ExpressionType::Bool; RefPtr ExpressionType::UInt; RefPtr ExpressionType::Int; RefPtr ExpressionType::Float; RefPtr ExpressionType::Float2; RefPtr ExpressionType::Void; #endif RefPtr ExpressionType::Error; RefPtr ExpressionType::initializerListType; RefPtr ExpressionType::Overloaded; Dictionary> ExpressionType::sBuiltinTypes; Dictionary ExpressionType::sMagicDecls; List> ExpressionType::sCanonicalTypes; void ExpressionType::Init() { Error = new ErrorType(); initializerListType = new InitializerListType(); Overloaded = new OverloadGroupType(); } void ExpressionType::Finalize() { Error = nullptr; initializerListType = nullptr; Overloaded = nullptr; // Note(tfoley): This seems to be just about the only way to clear out a List sCanonicalTypes = List>(); sBuiltinTypes = Dictionary>(); sMagicDecls = Dictionary(); } bool ArrayExpressionType::EqualsImpl(ExpressionType * type) { auto arrType = type->AsArrayType(); if (!arrType) return false; return (ArrayLength == arrType->ArrayLength && BaseType->Equals(arrType->BaseType.Ptr())); } ExpressionType* ArrayExpressionType::CreateCanonicalType() { auto canonicalBaseType = BaseType->GetCanonicalType(); auto canonicalArrayType = new ArrayExpressionType(); sCanonicalTypes.Add(canonicalArrayType); canonicalArrayType->BaseType = canonicalBaseType; canonicalArrayType->ArrayLength = ArrayLength; return canonicalArrayType; } int ArrayExpressionType::GetHashCode() { if (ArrayLength) return (BaseType->GetHashCode() * 16777619) ^ ArrayLength->GetHashCode(); else return BaseType->GetHashCode(); } CoreLib::Basic::String ArrayExpressionType::ToString() { if (ArrayLength) return BaseType->ToString() + "[" + ArrayLength->ToString() + "]"; else return BaseType->ToString() + "[]"; } RefPtr GenericAppExpr::Accept(SyntaxVisitor * visitor) { return visitor->VisitGenericApp(this); } // DeclRefType String DeclRefType::ToString() { return declRef.GetName(); } int DeclRefType::GetHashCode() { return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code()); } bool DeclRefType::EqualsImpl(ExpressionType * type) { if (auto declRefType = type->AsDeclRefType()) { return declRef.Equals(declRefType->declRef); } return false; } ExpressionType* DeclRefType::CreateCanonicalType() { // A declaration reference is already canonical return this; } RefPtr DeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff) { if (!subst) return this; // the case we especially care about is when this type references a declaration // of a generic parameter, since that is what we might be substituting... if (auto genericTypeParamDecl = dynamic_cast(declRef.GetDecl())) { // search for a substitution that might apply to us for (auto s = subst; s; s = s->outer.Ptr()) { // the generic decl associated with the substitution list must be // the generic decl that declared this parameter auto genericDecl = s->genericDecl; if (genericDecl != genericTypeParamDecl->ParentDecl) continue; int index = 0; for (auto m : genericDecl->Members) { if (m.Ptr() == genericTypeParamDecl) { // We've found it, so return the corresponding specialization argument (*ioDiff)++; return s->args[index]; } else if(auto typeParam = m.As()) { index++; } else if(auto valParam = m.As()) { index++; } else { } } } } int diff = 0; DeclRef substDeclRef = declRef.SubstituteImpl(subst, &diff); if (!diff) return this; // Make sure to record the difference! *ioDiff += diff; // Re-construct the type in case we are using a specialized sub-class return DeclRefType::Create(substDeclRef); } static RefPtr ExtractGenericArgType(RefPtr val) { auto type = val.As(); assert(type.Ptr()); return type; } static RefPtr ExtractGenericArgInteger(RefPtr val) { auto intVal = val.As(); assert(intVal.Ptr()); return intVal; } // TODO: need to figure out how to unify this with the logic // in the generic case... DeclRefType* DeclRefType::Create(DeclRef declRef) { if (auto builtinMod = declRef.GetDecl()->FindModifier()) { auto type = new BasicExpressionType(builtinMod->tag); type->declRef = declRef; return type; } else if (auto magicMod = declRef.GetDecl()->FindModifier()) { Substitutions* subst = declRef.substitutions.Ptr(); if (magicMod->name == "SamplerState") { auto type = new SamplerStateType(); type->declRef = declRef; type->flavor = SamplerStateType::Flavor(magicMod->tag); return type; } else if (magicMod->name == "Vector") { assert(subst && subst->args.Count() == 2); auto vecType = new VectorExpressionType(); vecType->declRef = declRef; vecType->elementType = ExtractGenericArgType(subst->args[0]); vecType->elementCount = ExtractGenericArgInteger(subst->args[1]); return vecType; } else if (magicMod->name == "Matrix") { assert(subst && subst->args.Count() == 3); auto matType = new MatrixExpressionType(); matType->declRef = declRef; return matType; } else if (magicMod->name == "Texture") { assert(subst && subst->args.Count() >= 1); auto textureType = new TextureType( TextureType::Flavor(magicMod->tag), ExtractGenericArgType(subst->args[0])); textureType->declRef = declRef; return textureType; } else if (magicMod->name == "TextureSampler") { assert(subst && subst->args.Count() >= 1); auto textureType = new TextureSamplerType( TextureType::Flavor(magicMod->tag), ExtractGenericArgType(subst->args[0])); textureType->declRef = declRef; return textureType; } else if (magicMod->name == "GLSLImageType") { assert(subst && subst->args.Count() >= 1); auto textureType = new GLSLImageType( TextureType::Flavor(magicMod->tag), ExtractGenericArgType(subst->args[0])); textureType->declRef = declRef; return textureType; } #define CASE(n,T) \ else if(magicMod->name == #n) { \ assert(subst && subst->args.Count() == 1); \ auto type = new T(); \ type->elementType = ExtractGenericArgType(subst->args[0]); \ type->declRef = declRef; \ return type; \ } CASE(ConstantBuffer, ConstantBufferType) CASE(TextureBuffer, TextureBufferType) CASE(GLSLInputParameterBlockType, GLSLInputParameterBlockType) CASE(GLSLOutputParameterBlockType, GLSLOutputParameterBlockType) CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType) CASE(PackedBuffer, PackedBufferType) CASE(Uniform, UniformBufferType) CASE(Patch, PatchType) CASE(HLSLBufferType, HLSLBufferType) CASE(HLSLStructuredBufferType, HLSLStructuredBufferType) CASE(HLSLRWBufferType, HLSLRWBufferType) CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType) CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType) CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType) CASE(HLSLInputPatchType, HLSLInputPatchType) CASE(HLSLOutputPatchType, HLSLOutputPatchType) CASE(HLSLPointStreamType, HLSLPointStreamType) CASE(HLSLLineStreamType, HLSLPointStreamType) CASE(HLSLTriangleStreamType, HLSLPointStreamType) #undef CASE // "magic" builtin types which have no generic parameters #define CASE(n,T) \ else if(magicMod->name == #n) { \ auto type = new T(); \ type->declRef = declRef; \ return type; \ } CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType) CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType) CASE(UntypedBufferResourceType, UntypedBufferResourceType) CASE(GLSLInputAttachmentType, GLSLInputAttachmentType) #undef CASE else { throw "unimplemented"; } } else { return new DeclRefType(declRef); } } // OverloadGroupType String OverloadGroupType::ToString() { return "overload group"; } bool OverloadGroupType::EqualsImpl(ExpressionType * /*type*/) { return false; } ExpressionType* OverloadGroupType::CreateCanonicalType() { return this; } int OverloadGroupType::GetHashCode() { return (int)(int64_t)(void*)this; } // InitializerListType String InitializerListType::ToString() { return "initializer list"; } bool InitializerListType::EqualsImpl(ExpressionType * /*type*/) { return false; } ExpressionType* InitializerListType::CreateCanonicalType() { return this; } int InitializerListType::GetHashCode() { return (int)(int64_t)(void*)this; } // ErrorType String ErrorType::ToString() { return "error"; } bool ErrorType::EqualsImpl(ExpressionType* type) { if (auto errorType = type->As()) return true; return false; } ExpressionType* ErrorType::CreateCanonicalType() { return this; } int ErrorType::GetHashCode() { return (int)(int64_t)(void*)this; } // NamedExpressionType String NamedExpressionType::ToString() { return declRef.GetName(); } bool NamedExpressionType::EqualsImpl(ExpressionType * /*type*/) { assert(!"unreachable"); return false; } ExpressionType* NamedExpressionType::CreateCanonicalType() { return declRef.GetType()->GetCanonicalType(); } int NamedExpressionType::GetHashCode() { assert(!"unreachable"); return 0; } // FuncType String FuncType::ToString() { // TODO: a better approach than this if (declRef) return declRef.GetName(); else return "/* unknown FuncType */"; } bool FuncType::EqualsImpl(ExpressionType * type) { if (auto funcType = type->As()) { return declRef == funcType->declRef; } return false; } ExpressionType* FuncType::CreateCanonicalType() { return this; } int FuncType::GetHashCode() { return declRef.GetHashCode(); } // TypeType String TypeType::ToString() { StringBuilder sb; sb << "typeof(" << type->ToString() << ")"; return sb.ProduceString(); } bool TypeType::EqualsImpl(ExpressionType * t) { if (auto typeType = t->As()) { return t->Equals(typeType->type); } return false; } ExpressionType* TypeType::CreateCanonicalType() { auto canType = new TypeType(type->GetCanonicalType()); sCanonicalTypes.Add(canType); return canType; } int TypeType::GetHashCode() { assert(!"unreachable"); return 0; } // GenericDeclRefType String GenericDeclRefType::ToString() { // TODO: what is appropriate here? return ""; } bool GenericDeclRefType::EqualsImpl(ExpressionType * type) { if (auto genericDeclRefType = type->As()) { return declRef.Equals(genericDeclRefType->declRef); } return false; } int GenericDeclRefType::GetHashCode() { return declRef.GetHashCode(); } ExpressionType* GenericDeclRefType::CreateCanonicalType() { return this; } // ArithmeticExpressionType // VectorExpressionType String VectorExpressionType::ToString() { StringBuilder sb; sb << "vector<" << elementType->ToString() << "," << elementCount->ToString() << ">"; return sb.ProduceString(); } BasicExpressionType* VectorExpressionType::GetScalarType() { return elementType->AsBasicType(); } // MatrixExpressionType String MatrixExpressionType::ToString() { StringBuilder sb; sb << "matrix<" << getElementType()->ToString() << "," << getRowCount()->ToString() << "," << getColumnCount()->ToString() << ">"; return sb.ProduceString(); } BasicExpressionType* MatrixExpressionType::GetScalarType() { return getElementType()->AsBasicType(); } ExpressionType* MatrixExpressionType::getElementType() { return this->declRef.substitutions->args[0].As().Ptr(); } IntVal* MatrixExpressionType::getRowCount() { return this->declRef.substitutions->args[1].As().Ptr(); } IntVal* MatrixExpressionType::getColumnCount() { return this->declRef.substitutions->args[2].As().Ptr(); } // #if 0 String GetOperatorFunctionName(Operator op) { switch (op) { case Operator::Add: case Operator::AddAssign: return "+"; case Operator::Sub: case Operator::SubAssign: return "-"; case Operator::Neg: return "-"; case Operator::Not: return "!"; case Operator::BitNot: return "~"; case Operator::PreInc: case Operator::PostInc: return "++"; case Operator::PreDec: case Operator::PostDec: return "--"; case Operator::Mul: case Operator::MulAssign: return "*"; case Operator::Div: case Operator::DivAssign: return "/"; case Operator::Mod: case Operator::ModAssign: return "%"; case Operator::Lsh: case Operator::LshAssign: return "<<"; case Operator::Rsh: case Operator::RshAssign: return ">>"; case Operator::Eql: return "=="; case Operator::Neq: return "!="; case Operator::Greater: return ">"; case Operator::Less: return "<"; case Operator::Geq: return ">="; case Operator::Leq: return "<="; case Operator::BitAnd: case Operator::AndAssign: return "&"; case Operator::BitXor: case Operator::XorAssign: return "^"; case Operator::BitOr: case Operator::OrAssign: return "|"; case Operator::And: return "&&"; case Operator::Or: return "||"; case Operator::Sequence: return ","; case Operator::Select: return "?:"; case Operator::Assign: return "="; default: return ""; } } #endif String OperatorToString(Operator op) { switch (op) { case Slang::Compiler::Operator::Neg: return "-"; case Slang::Compiler::Operator::Not: return "!"; case Slang::Compiler::Operator::PreInc: return "++"; case Slang::Compiler::Operator::PreDec: return "--"; case Slang::Compiler::Operator::PostInc: return "++"; case Slang::Compiler::Operator::PostDec: return "--"; case Slang::Compiler::Operator::Mul: case Slang::Compiler::Operator::MulAssign: return "*"; case Slang::Compiler::Operator::Div: case Slang::Compiler::Operator::DivAssign: return "/"; case Slang::Compiler::Operator::Mod: case Slang::Compiler::Operator::ModAssign: return "%"; case Slang::Compiler::Operator::Add: case Slang::Compiler::Operator::AddAssign: return "+"; case Slang::Compiler::Operator::Sub: case Slang::Compiler::Operator::SubAssign: return "-"; case Slang::Compiler::Operator::Lsh: case Slang::Compiler::Operator::LshAssign: return "<<"; case Slang::Compiler::Operator::Rsh: case Slang::Compiler::Operator::RshAssign: return ">>"; case Slang::Compiler::Operator::Eql: return "=="; case Slang::Compiler::Operator::Neq: return "!="; case Slang::Compiler::Operator::Greater: return ">"; case Slang::Compiler::Operator::Less: return "<"; case Slang::Compiler::Operator::Geq: return ">="; case Slang::Compiler::Operator::Leq: return "<="; case Slang::Compiler::Operator::BitAnd: case Slang::Compiler::Operator::AndAssign: return "&"; case Slang::Compiler::Operator::BitXor: case Slang::Compiler::Operator::XorAssign: return "^"; case Slang::Compiler::Operator::BitOr: case Slang::Compiler::Operator::OrAssign: return "|"; case Slang::Compiler::Operator::And: return "&&"; case Slang::Compiler::Operator::Or: return "||"; case Slang::Compiler::Operator::Assign: return "="; default: return "ERROR"; } } // TypeExp TypeExp TypeExp::Accept(SyntaxVisitor* visitor) { return visitor->VisitTypeExp(*this); } // BuiltinTypeModifier // MagicTypeModifier // GenericDecl RefPtr GenericDecl::Accept(SyntaxVisitor * visitor) { return visitor->VisitGenericDecl(this); } // GenericTypeParamDecl RefPtr GenericTypeParamDecl::Accept(SyntaxVisitor * /*visitor*/) { //throw "unimplemented"; return this; } // GenericTypeConstraintDecl RefPtr GenericTypeConstraintDecl::Accept(SyntaxVisitor * visitor) { return this; } // GenericValueParamDecl RefPtr GenericValueParamDecl::Accept(SyntaxVisitor * /*visitor*/) { //throw "unimplemented"; return this; } // GenericParamIntVal bool GenericParamIntVal::EqualsVal(Val* val) { if (auto genericParamVal = dynamic_cast(val)) { return declRef.Equals(genericParamVal->declRef); } return false; } String GenericParamIntVal::ToString() { return declRef.GetName(); } int GenericParamIntVal::GetHashCode() { return declRef.GetHashCode() ^ 0xFFFF; } RefPtr GenericParamIntVal::SubstituteImpl(Substitutions* subst, int* ioDiff) { // search for a substitution that might apply to us for (auto s = subst; s; s = s->outer.Ptr()) { // the generic decl associated with the substitution list must be // the generic decl that declared this parameter auto genericDecl = s->genericDecl; if (genericDecl != declRef.GetDecl()->ParentDecl) continue; int index = 0; for (auto m : genericDecl->Members) { if (m.Ptr() == declRef.GetDecl()) { // We've found it, so return the corresponding specialization argument (*ioDiff)++; return s->args[index]; } else if(auto typeParam = m.As()) { index++; } else if(auto valParam = m.As()) { index++; } else { } } } // Nothing found: don't substittue. return this; } // ExtensionDecl RefPtr ExtensionDecl::Accept(SyntaxVisitor * visitor) { visitor->VisitExtensionDecl(this); return this; } // ConstructorDecl RefPtr ConstructorDecl::Accept(SyntaxVisitor * visitor) { visitor->VisitConstructorDecl(this); return this; } // SubscriptDecl RefPtr SubscriptDecl::Accept(SyntaxVisitor * visitor) { visitor->visitSubscriptDecl(this); return this; } // AccessorDecl RefPtr AccessorDecl::Accept(SyntaxVisitor * visitor) { visitor->visitAccessorDecl(this); return this; } // Substitutions RefPtr Substitutions::SubstituteImpl(Substitutions* subst, int* ioDiff) { if (!this) return nullptr; int diff = 0; auto outerSubst = outer->SubstituteImpl(subst, &diff); List> substArgs; for (auto a : args) { substArgs.Add(a->SubstituteImpl(subst, &diff)); } if (!diff) return this; (*ioDiff)++; auto substSubst = new Substitutions(); substSubst->genericDecl = genericDecl; substSubst->args = substArgs; return substSubst; } bool Substitutions::Equals(Substitutions* subst) { // both must be NULL, or non-NULL if (!this || !subst) return !this && !subst; if (genericDecl != subst->genericDecl) return false; int argCount = args.Count(); assert(args.Count() == subst->args.Count()); for (int aa = 0; aa < argCount; ++aa) { if (!args[aa]->EqualsVal(subst->args[aa].Ptr())) return false; } if (!outer->Equals(subst->outer.Ptr())) return false; return true; } // DeclRef RefPtr DeclRef::Substitute(RefPtr type) const { // No substitutions? Easy. if (!substitutions) return type; // Otherwise we need to recurse on the type structure // and apply substitutions where it makes sense return type->Substitute(substitutions.Ptr()).As(); } DeclRef DeclRef::Substitute(DeclRef declRef) const { if(!substitutions) return declRef; int diff = 0; return declRef.SubstituteImpl(substitutions.Ptr(), &diff); } RefPtr DeclRef::Substitute(RefPtr expr) const { // No substitutions? Easy. if (!substitutions) return expr; assert(!"unimplemented"); return expr; } DeclRef DeclRef::SubstituteImpl(Substitutions* subst, int* ioDiff) { if (!substitutions) return *this; int diff = 0; RefPtr substSubst = substitutions->SubstituteImpl(subst, &diff); if (!diff) return *this; *ioDiff += diff; DeclRef substDeclRef; substDeclRef.decl = decl; substDeclRef.substitutions = substSubst; return substDeclRef; } // Check if this is an equivalent declaration reference to another bool DeclRef::Equals(DeclRef const& declRef) const { if (decl != declRef.decl) return false; if (!substitutions->Equals(declRef.substitutions.Ptr())) return false; return true; } // Convenience accessors for common properties of declarations String const& DeclRef::GetName() const { return decl->Name.Content; } DeclRef DeclRef::GetParent() const { auto parentDecl = decl->ParentDecl; if (auto parentGeneric = dynamic_cast(parentDecl)) { // We need to strip away one layer of specialization assert(substitutions); return DeclRef(parentGeneric, substitutions->outer); } else { // If the parent isn't a generic, then it must // use the same specializations as this declaration return DeclRef(parentDecl, substitutions); } } int DeclRef::GetHashCode() const { auto rs = PointerHash<1>::GetHashCode(decl); if (substitutions) { rs *= 16777619; rs ^= substitutions->GetHashCode(); } return rs; } // Val RefPtr Val::Substitute(Substitutions* subst) { if (!this) return nullptr; if (!subst) return this; int diff = 0; return SubstituteImpl(subst, &diff); } RefPtr Val::SubstituteImpl(Substitutions* /*subst*/, int* /*ioDiff*/) { // Default behavior is to not substitute at all return this; } // IntVal int GetIntVal(RefPtr val) { if (auto constantVal = val.As()) { return constantVal->value; } assert(!"unexpected"); return 0; } // ConstantIntVal bool ConstantIntVal::EqualsVal(Val* val) { if (auto intVal = dynamic_cast(val)) return value == intVal->value; return false; } String ConstantIntVal::ToString() { return String(value); } int ConstantIntVal::GetHashCode() { return value; } // SwitchStmt RefPtr SwitchStmt::Accept(SyntaxVisitor * visitor) { return visitor->VisitSwitchStmt(this); } RefPtr CaseStmt::Accept(SyntaxVisitor * visitor) { return visitor->VisitCaseStmt(this); } RefPtr DefaultStmt::Accept(SyntaxVisitor * visitor) { return visitor->VisitDefaultStmt(this); } // InterfaceDecl RefPtr InterfaceDecl::Accept(SyntaxVisitor * visitor) { visitor->visitInterfaceDecl(this); return this; } // InheritanceDecl RefPtr InheritanceDecl::Accept(SyntaxVisitor * visitor) { visitor->visitInheritanceDecl(this); return this; } // SharedTypeExpr RefPtr SharedTypeExpr::Accept(SyntaxVisitor * visitor) { return visitor->VisitSharedTypeExpr(this); } // OperatorExpressionSyntaxNode #if 0 void OperatorExpressionSyntaxNode::SetOperator(RefPtr scope, Slang::Compiler::Operator op) { this->Operator = op; auto opExpr = new VarExpressionSyntaxNode(); opExpr->Variable = GetOperatorFunctionName(Operator); opExpr->scope = scope; opExpr->Position = this->Position; this->FunctionExpr = opExpr; } #endif RefPtr OperatorExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) { return visitor->VisitOperatorExpression(this); } // DeclGroup RefPtr DeclGroup::Accept(SyntaxVisitor * visitor) { visitor->VisitDeclGroup(this); return this; } // void RegisterBuiltinDecl( RefPtr decl, RefPtr modifier) { auto type = DeclRefType::Create(DeclRef(decl.Ptr(), nullptr)); ExpressionType::sBuiltinTypes[(int)modifier->tag] = type; } void RegisterMagicDecl( RefPtr decl, RefPtr modifier) { ExpressionType::sMagicDecls[modifier->name] = decl.Ptr(); } RefPtr findMagicDecl( String const& name) { return ExpressionType::sMagicDecls[name].GetValue(); } ExpressionType* ExpressionType::GetBool() { return sBuiltinTypes[(int)BaseType::Bool].GetValue().Ptr(); } ExpressionType* ExpressionType::GetFloat() { return sBuiltinTypes[(int)BaseType::Float].GetValue().Ptr(); } ExpressionType* ExpressionType::GetInt() { return sBuiltinTypes[(int)BaseType::Int].GetValue().Ptr(); } ExpressionType* ExpressionType::GetUInt() { return sBuiltinTypes[(int)BaseType::UInt].GetValue().Ptr(); } ExpressionType* ExpressionType::GetVoid() { return sBuiltinTypes[(int)BaseType::Void].GetValue().Ptr(); } ExpressionType* ExpressionType::getInitializerListType() { return initializerListType.Ptr(); } ExpressionType* ExpressionType::GetError() { return ExpressionType::Error.Ptr(); } // RefPtr UnparsedStmt::Accept(SyntaxVisitor * visitor) { return this; } // RefPtr InitializerListExpr::Accept(SyntaxVisitor * visitor) { return visitor->visitInitializerListExpr(this); } // RefPtr ModifierDecl::Accept(SyntaxVisitor * visitor) { return this; } // RefPtr EmptyDecl::Accept(SyntaxVisitor * visitor) { return this; } // SyntaxNodeBase* createInstanceOfSyntaxClassByName( String const& name) { if(0) {} #define CASE(NAME) \ else if(name == #NAME) return new NAME() CASE(GLSLBufferModifier); CASE(GLSLWriteOnlyModifier); CASE(GLSLReadOnlyModifier); CASE(GLSLPatchModifier); CASE(SimpleModifier); #undef CASE else { assert(!"unexpected"); return nullptr; } } IntrinsicOp findIntrinsicOp(char const* name) { // TODO: need to make this faster by using a dictionary... if (0) {} #define INTRINSIC(NAME) else if(strcmp(name, #NAME) == 0) return IntrinsicOp::NAME; #include "intrinsic-defs.h" return IntrinsicOp::Unknown; } } }