diff options
| -rw-r--r-- | source/slang/check.cpp | 77 | ||||
| -rw-r--r-- | source/slang/diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 102 | ||||
| -rw-r--r-- | source/slang/slang-stdlib.cpp | 31 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 12 | ||||
| -rw-r--r-- | source/slang/syntax.h | 61 | ||||
| -rw-r--r-- | tests/front-end/interface.slang | 65 |
7 files changed, 248 insertions, 102 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index a1b7393fb..6ff8efe9e 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -546,6 +546,12 @@ namespace Slang // No conversion at all kConversionCost_None = 0, + // Conversions based on explicit sub-typing relationships are the cheapest + // + // TODO(tfoley): We will eventually need a discipline for ranking + // when two up-casts are comparable. + kConversionCost_CastToInterface = 50, + // Conversion that is lossless and keeps the "kind" of the value the same kConversionCost_RankPromotion = 100, @@ -901,6 +907,25 @@ namespace Slang } } + if (auto toDeclRefType = toType->As<DeclRefType>()) + { + auto toTypeDeclRef = toDeclRefType->declRef; + if (auto interfaceDeclRef = toTypeDeclRef.As<InterfaceDeclRef>()) + { + // Trying to convert to an interface type. + // + // We will allow this if the type conforms to the interface. + if (DoesTypeConformToInterface(fromType, interfaceDeclRef)) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = kConversionCost_CastToInterface; + return true; + } + } + } + // TODO: more cases! return false; @@ -1045,24 +1070,32 @@ namespace Slang return genericDecl; } - virtual void VisitTraitConformanceDecl(TraitConformanceDecl* conformanceDecl) override + virtual void visitInterfaceDecl(InterfaceDecl* decl) override + { + // TODO: do some actual checking of members here + } + + virtual void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) override { - // check the type being conformed to - auto base = conformanceDecl->base; + // check the type being inherited from + auto base = inheritanceDecl->base; base = TranslateTypeNode(base); - conformanceDecl->base = base; + inheritanceDecl->base = base; + + // For now we only allow inheritance from interfaces, so + // we will validate that the type expression names an interface if(auto declRefType = base.type->As<DeclRefType>()) { - if(auto traitDeclRef = declRefType->declRef.As<TraitDeclRef>()) + if(auto interfaceDeclRef = declRefType->declRef.As<InterfaceDeclRef>()) { - conformanceDecl->traitDeclRef = traitDeclRef; return; } } - // We expected a trait here - getSink()->diagnose( conformanceDecl, Diagnostics::expectedATraitGot, base.type); + // If type expression didn't name an interface, we'll emit an error here + // TODO: deal with the case of an error in the type expression (don't cascade) + getSink()->diagnose( base.exp, Diagnostics::expectedAnInterfaceGot, base.type); } RefPtr<ConstantIntVal> checkConstantIntVal( @@ -2479,20 +2512,24 @@ namespace Slang vectorType->elementCount); } - bool DoesTypeConformToTrait( + bool DoesTypeConformToInterface( RefPtr<ExpressionType> type, - TraitDeclRef traitDeclRef) + InterfaceDeclRef interfaceDeclRef) { // for now look up a conformance member... if(auto declRefType = type->As<DeclRefType>()) { if( auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>() ) { - for( auto conformanceRef : aggTypeDeclRef.GetMembersOfType<TraitConformanceDeclRef>()) + for( auto inheritanceDeclRef : aggTypeDeclRef.GetMembersOfType<InheritanceDeclRef>()) { - EnsureDecl(conformanceRef.GetDecl()); + EnsureDecl(inheritanceDeclRef.GetDecl()); + + auto inheritedDeclRefType = inheritanceDeclRef.getBaseType()->As<DeclRefType>(); + if (!inheritedDeclRefType) + continue; - if(traitDeclRef.Equals(conformanceRef.GetTraitDeclRef())) + if(interfaceDeclRef.Equals(inheritedDeclRefType->declRef)) return true; } } @@ -2502,12 +2539,12 @@ namespace Slang return false; } - RefPtr<ExpressionType> TryJoinTypeWithTrait( + RefPtr<ExpressionType> TryJoinTypeWithInterface( RefPtr<ExpressionType> type, - TraitDeclRef traitDeclRef) + InterfaceDeclRef interfaceDeclRef) { // The most basic test here should be: does the type declare conformance to the trait. - if(DoesTypeConformToTrait(type, traitDeclRef)) + if(DoesTypeConformToInterface(type, interfaceDeclRef)) return type; // There is a more nuanced case if `type` is a builtin type, and we need to make it @@ -2590,18 +2627,18 @@ namespace Slang // HACK: trying to work trait types in here... if(auto leftDeclRefType = left->As<DeclRefType>()) { - if( auto leftTraitRef = leftDeclRefType->declRef.As<TraitDeclRef>() ) + if( auto leftInterfaceRef = leftDeclRefType->declRef.As<InterfaceDeclRef>() ) { // - return TryJoinTypeWithTrait(right, leftTraitRef); + return TryJoinTypeWithInterface(right, leftInterfaceRef); } } if(auto rightDeclRefType = right->As<DeclRefType>()) { - if( auto rightTraitRef = rightDeclRefType->declRef.As<TraitDeclRef>() ) + if( auto rightInterfaceRef = rightDeclRefType->declRef.As<InterfaceDeclRef>() ) { // - return TryJoinTypeWithTrait(left, rightTraitRef); + return TryJoinTypeWithInterface(left, rightInterfaceRef); } } diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 3f690a5da..ff93e51fb 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -270,7 +270,7 @@ DIAGNOSTIC(39999, Error, expectedAGeneric, "expected a generic when using '<...> DIAGNOSTIC(39999, Error, genericArgumentInferenceFailed, "could not specialize generic for arguments of type $0") DIAGNOSTIC(39999, Note, genericSignatureTried, "see declaration of $0") -DIAGNOSTIC(39999, Error, expectedATraitGot, "expected a trait, got '$0'") +DIAGNOSTIC(39999, Error, expectedAnInterfaceGot, "expected an interface, got '$0'") DIAGNOSTIC(39999, Error, ambiguousReference, "amiguous reference to '$0'"); diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index edd77c9ae..9e76a68b9 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -120,6 +120,11 @@ namespace Slang ContainerDecl* containerDecl, TokenType closingToken); + // Parse the `{}`-delimeted body of an aggregate type declaration + static void parseAggTypeDeclBody( + Parser* parser, + AggTypeDeclBase* decl); + static RefPtr<Modifier> ParseOptSemantics( Parser* parser); @@ -1016,7 +1021,7 @@ namespace Slang } - static String GenerateName(Parser* parser, String const& base) + static String GenerateName(Parser* /*parser*/, String const& base) { // TODO: somehow mangle the name to avoid clashes return base; @@ -1439,8 +1444,6 @@ namespace Slang declGroupBuilder.addDecl(firstDecl); return declGroupBuilder.getResult(); - - return firstDecl; } // Otherwise we have multiple declarations in a sequence, and these @@ -1684,8 +1687,7 @@ namespace Slang ParseOptSemantics(parser, bufferVarDecl.Ptr()); // The declarations in the body belong to the data type. - parser->ReadToken(TokenType::LBrace); - ParseDeclBody(parser, bufferDataTypeDecl.Ptr(), TokenType::RBrace); + parseAggTypeDeclBody(parser, bufferDataTypeDecl.Ptr()); // All HLSL buffer declarations are "transparent" in that their // members are implicitly made visible in the parent scope. @@ -1833,8 +1835,7 @@ namespace Slang blockVarDecl->Type.exp = blockVarTypeExpr; // The declarations in the body belong to the data type. - parser->ReadToken(TokenType::LBrace); - ParseDeclBody(parser, blockDataTypeDecl.Ptr(), TokenType::RBrace); + parseAggTypeDeclBody(parser, blockDataTypeDecl.Ptr()); if( parser->LookAheadToken(TokenType::Identifier) ) { @@ -1964,48 +1965,47 @@ parser->ReadToken(TokenType::Comma); return decl; } - static RefPtr<Decl> ParseTraitConformanceDecl( - Parser* parser) - { - RefPtr<TraitConformanceDecl> decl = new TraitConformanceDecl(); - parser->FillPosition(decl.Ptr()); - parser->ReadToken("__conforms"); - - decl->base = parser->ParseTypeExp(); - - return decl; - } - - static RefPtr<ExtensionDecl> ParseExtensionDecl(Parser* parser) { RefPtr<ExtensionDecl> decl = new ExtensionDecl(); parser->FillPosition(decl.Ptr()); parser->ReadToken("__extension"); decl->targetType = parser->ParseTypeExp(); - parser->ReadToken(TokenType::LBrace); - ParseDeclBody(parser, decl.Ptr(), TokenType::RBrace); + + parseAggTypeDeclBody(parser, decl.Ptr()); + return decl; } - static RefPtr<TraitDecl> ParseTraitDecl(Parser* parser) + static void parseOptionalInheritanceClause(Parser* parser, AggTypeDecl* decl) { - RefPtr<TraitDecl> decl = new TraitDecl(); - parser->FillPosition(decl.Ptr()); - parser->ReadToken("__trait"); - decl->Name = parser->ReadToken(TokenType::Identifier); - if( AdvanceIf(parser, TokenType::Colon) ) { do { auto base = parser->ParseTypeExp(); - decl->bases.Add(base); + + auto inheritanceDecl = new InheritanceDecl(); + inheritanceDecl->Position = base.exp->Position; + inheritanceDecl->base = base; + + AddMember(decl, inheritanceDecl); + } while( AdvanceIf(parser, TokenType::Comma) ); } + } + + static RefPtr<InterfaceDecl> parseInterfaceDecl(Parser* parser) + { + RefPtr<InterfaceDecl> decl = new InterfaceDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("interface"); + decl->Name = parser->ReadToken(TokenType::Identifier); + + parseOptionalInheritanceClause(parser, decl.Ptr()); + + parseAggTypeDeclBody(parser, decl.Ptr()); - parser->ReadToken(TokenType::LBrace); - ParseDeclBody(parser, decl.Ptr(), TokenType::RBrace); return decl; } @@ -2149,16 +2149,14 @@ parser->ReadToken(TokenType::Comma); decl = ParseHLSLBufferDecl(parser); else if (parser->LookAheadToken("__generic")) decl = ParseGenericDecl(parser); - else if (parser->LookAheadToken("__conforms")) - decl = ParseTraitConformanceDecl(parser); else if (parser->LookAheadToken("__extension")) decl = ParseExtensionDecl(parser); else if (parser->LookAheadToken("__init")) decl = ParseConstructorDecl(parser); else if (parser->LookAheadToken("__subscript")) decl = ParseSubscriptDecl(parser); - else if (parser->LookAheadToken("__trait")) - decl = ParseTraitDecl(parser); + else if (parser->LookAheadToken("interface")) + decl = parseInterfaceDecl(parser); else if(parser->LookAheadToken("__modifier")) decl = parseModifierDecl(parser); else if(parser->LookAheadToken("__import")) @@ -2251,6 +2249,27 @@ parser->ReadToken(TokenType::Comma); } } + // Parse the `{}`-delimeted body of an aggregate type declaration + static void parseAggTypeDeclBody( + Parser* parser, + AggTypeDeclBase* decl) + { + // TODO: the scope used for the body might need to be + // slightly specialized to deal with the complexity + // of how `this` works. + // + // Alternatively, that complexity can be pushed down + // to semantic analysis so that it doesn't clutter + // things here. + parser->PushScope(decl); + + parser->ReadToken(TokenType::LBrace); + ParseDeclBody(parser, decl, TokenType::RBrace); + + parser->PopScope(); + } + + void Parser::parseSourceFile(ProgramSyntaxNode* program) { if (outerScope) @@ -2281,9 +2300,15 @@ parser->ReadToken(TokenType::Comma); RefPtr<StructSyntaxNode> rs = new StructSyntaxNode(); FillPosition(rs.Ptr()); ReadToken("struct"); + + // TODO: support `struct` declaration without tag rs->Name = ReadToken(TokenType::Identifier); - ReadToken(TokenType::LBrace); - ParseDeclBody(this, rs.Ptr(), TokenType::RBrace); + + // We allow for an inheritance clause on a `struct` + // so that it can conform to interfaces. + parseOptionalInheritanceClause(this, rs.Ptr()); + + parseAggTypeDeclBody(this, rs.Ptr()); return rs; } @@ -2295,7 +2320,8 @@ parser->ReadToken(TokenType::Comma); ReadToken("class"); rs->Name = ReadToken(TokenType::Identifier); ReadToken(TokenType::LBrace); - ParseDeclBody(this, rs.Ptr(), TokenType::RBrace); + parseOptionalInheritanceClause(this, rs.Ptr()); + parseAggTypeDeclBody(this, rs.Ptr()); return rs; } diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp index 494a32e4d..40c391bf4 100644 --- a/source/slang/slang-stdlib.cpp +++ b/source/slang/slang-stdlib.cpp @@ -239,22 +239,22 @@ __generic<T> __magic_type(HLSLLineStreamType) struct TriangleStream {}; // Note(tfoley): Trying to systematically add all the HLSL builtins // A type that can be used as an operand for builtins -__trait __BuiltinType {} +interface __BuiltinType {} // A type that can be used for arithmetic operations -__trait __BuiltinArithmeticType : __BuiltinType {} +interface __BuiltinArithmeticType : __BuiltinType {} // A type that logically has a sign (positive/negative/zero) -__trait __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} +interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} // A type that can represent integers -__trait __BuiltinIntegerType : __BuiltinArithmeticType {} +interface __BuiltinIntegerType : __BuiltinArithmeticType {} // A type that can represent non-integers -__trait __BuiltinRealType : __BuiltinArithmeticType {} +interface __BuiltinRealType : __BuiltinArithmeticType {} // A type that uses a floating-point representation -__trait __BuiltinFloatingPointType : __BuiltinRealType, __BuiltinSignedType {} +interface __BuiltinFloatingPointType : __BuiltinRealType, __BuiltinSignedArithmeticType {} // Try to terminate the current draw or dispatch call (HLSL SM 4.0) __intrinsic void abort(); @@ -1088,33 +1088,36 @@ namespace Slang for (int tt = 0; tt < kBaseTypeCount; ++tt) { EMIT_LINE_DIRECTIVE(); - sb << "__builtin_type(" << int(kBaseTypes[tt].tag) << ") struct " << kBaseTypes[tt].name << "\n{\n"; + sb << "__builtin_type(" << int(kBaseTypes[tt].tag) << ") struct " << kBaseTypes[tt].name; - // Declare trait conformances for this type + // Declare interface conformances for this type - sb << "__conforms __BuiltinType;\n"; + sb << "\n : __BuiltinType\n"; switch( kBaseTypes[tt].tag ) { case BaseType::Float: - sb << "__conforms __BuiltinFloatingPointType;\n"; - sb << "__conforms __BuiltinRealType;\n"; + sb << "\n , __BuiltinFloatingPointType\n"; + sb << "\n , __BuiltinRealType\n"; // fall through to: case BaseType::Int: - sb << "__conforms __BuiltinSignedArithmeticType;\n"; + sb << "\n , __BuiltinSignedArithmeticType\n"; // fall through to: case BaseType::UInt: case BaseType::UInt64: - sb << "__conforms __BuiltinArithmeticType;\n"; + sb << "\n , __BuiltinArithmeticType\n"; // fall through to: case BaseType::Bool: - sb << "__conforms __BuiltinType;\n"; + sb << "\n , __BuiltinType\n"; break; default: break; } + sb << "\n{\n"; + + // Declare initializers to convert from various other types for( int ss = 0; ss < kBaseTypeCount; ++ss ) { diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index fffc6c725..e47c610c0 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1311,19 +1311,19 @@ namespace Slang return visitor->VisitDefaultStmt(this); } - // TraitDecl + // InterfaceDecl - RefPtr<SyntaxNode> TraitDecl::Accept(SyntaxVisitor * visitor) + RefPtr<SyntaxNode> InterfaceDecl::Accept(SyntaxVisitor * visitor) { - visitor->VisitTraitDecl(this); + visitor->visitInterfaceDecl(this); return this; } - // TraitConformanceDecl + // InheritanceDecl - RefPtr<SyntaxNode> TraitConformanceDecl::Accept(SyntaxVisitor * visitor) + RefPtr<SyntaxNode> InheritanceDecl::Accept(SyntaxVisitor * visitor) { - visitor->VisitTraitConformanceDecl(this); + visitor->visitInheritanceDecl(this); return this; } diff --git a/source/slang/syntax.h b/source/slang/syntax.h index ac2bcac87..642f4e99f 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -1453,8 +1453,26 @@ namespace Slang SLANG_DECLARE_DECL_REF(StructField) }; + // An `AggTypeDeclBase` captures the shared functionality + // between true aggregate type declarations and extension + // declarations: + // + // - Both can container members (they are `ContainerDecl`s) + // - Both can have declared bases + // - Both expose a `this` variable in their body + // + class AggTypeDeclBase : public ContainerDecl + { + public: + }; + + struct AggTypeDeclBaseRef : ContainerDeclRef + { + SLANG_DECLARE_DECL_REF(AggTypeDeclBase); + }; + // An extension to apply to an existing type - class ExtensionDecl : public ContainerDecl + class ExtensionDecl : public AggTypeDeclBase { public: TypeExp targetType; @@ -1466,7 +1484,7 @@ namespace Slang virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; }; - struct ExtensionDeclRef : ContainerDeclRef + struct ExtensionDeclRef : AggTypeDeclBaseRef { SLANG_DECLARE_DECL_REF(ExtensionDecl); @@ -1474,9 +1492,10 @@ namespace Slang }; // Declaration of a type that represents some sort of aggregate - class AggTypeDecl : public ContainerDecl + class AggTypeDecl : public AggTypeDeclBase { public: + // extensions that might apply to this declaration ExtensionDecl* candidateExtensions = nullptr; FilteredMemberList<StructField> GetFields() @@ -1505,7 +1524,7 @@ namespace Slang } }; - struct AggTypeDeclRef : public ContainerDeclRef + struct AggTypeDeclRef : public AggTypeDeclBaseRef { SLANG_DECLARE_DECL_REF(AggTypeDecl); @@ -1538,43 +1557,41 @@ namespace Slang FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); } }; - // A trait which other types can conform to - class TraitDecl : public AggTypeDecl + // An interface which other types can conform to + class InterfaceDecl : public AggTypeDecl { public: - List<TypeExp> bases; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; }; - struct TraitDeclRef : public AggTypeDeclRef + struct InterfaceDeclRef : public AggTypeDeclRef { - SLANG_DECLARE_DECL_REF(TraitDecl); + SLANG_DECLARE_DECL_REF(InterfaceDecl); }; - // A declaration that states that the enclosing type supports a given trait + // A kind of pseudo-member that represents an explicit + // or implicit inheritance relationship. // - // TODO: this same construct might be used for represent other inheritance-like cases - class TraitConformanceDecl : public Decl + class InheritanceDecl : public Decl { public: // The type expression as written TypeExp base; - // The trait that we found we conform to... - TraitDeclRef traitDeclRef; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; }; - struct TraitConformanceDeclRef : public DeclRef + struct InheritanceDeclRef : public DeclRef { - SLANG_DECLARE_DECL_REF(TraitConformanceDecl); + SLANG_DECLARE_DECL_REF(InheritanceDecl); - TraitDeclRef GetTraitDeclRef() { return Substitute(GetDecl()->traitDeclRef).As<TraitDeclRef>(); } + RefPtr<ExpressionType> getBaseType() { return Substitute(GetDecl()->base.type); } }; + // TODO: may eventually need sub-classes for explicit/direct vs. implicit/indirect inheritance + + // A declaration that represents a simple (non-aggregate) type class SimpleTypeDecl : public Decl { @@ -2744,11 +2761,9 @@ namespace Slang virtual void visitAccessorDecl(AccessorDecl* decl) = 0; - virtual void VisitTraitDecl(TraitDecl* /*decl*/) - {} + virtual void visitInterfaceDecl(InterfaceDecl* /*decl*/) = 0; - virtual void VisitTraitConformanceDecl(TraitConformanceDecl* /*decl*/) - {} + virtual void visitInheritanceDecl(InheritanceDecl* /*decl*/) = 0; virtual RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* typeExpr) { diff --git a/tests/front-end/interface.slang b/tests/front-end/interface.slang new file mode 100644 index 000000000..754addf61 --- /dev/null +++ b/tests/front-end/interface.slang @@ -0,0 +1,65 @@ +//TEST:SIMPLE: + +// Confirm that basic `interface` syntax stuff type-checks. + +// The example here is adapted from examples in Matt Pharr's +// chapter in GPU Gems: "An Introduction to Shader Interaces" + +struct LightSample +{ + float3 C; // radiance + float3 L; // direction +}; + +interface Light +{ + LightSample illuminate(float3 P_world); +} + +struct PointLight : Light +{ + float3 Plight_world; + float3 C; + + LightSample illuminate(float3 P_world) + { + float3 delta = Plight_world - P_world; + float3 L = normalize(delta); + float distance = length(delta); + + LightSample result; + result.L = L; + result.C = C * (1 / (distance*distance)); + return result; + } +}; + +// using the concrete type directly +float3 A( float3 P_world, PointLight light ) +{ + return light.illuminate(P_world).L; +} + +// using the abstract interface type +float3 B( float3 P_world, Light light ) +{ + return light.illuminate(P_world).L; +} + +// +float3 Test(float3 P_world, PointLight pointLight, Light light) +{ + // dconcrete type expected, concrete type provided + float3 a = A(P_world, pointLight); + + // abstract type expected, abstract type provided + float3 b = B(P_world, light); + + // abstract type expected, concrete type provided + float3 c = B(P_world, pointLight); + + // The remaining case (passing `Light` where `PointLight` is expected) + // should be an error, so we want a distinct test for that. + + return a + b + c; +}
\ No newline at end of file |
