From d5e2319c33115d0241dd9d2047c0a5f029553dde Mon Sep 17 00:00:00 2001 From: "YONGH\\yongh" Date: Thu, 2 Nov 2017 19:21:15 -0400 Subject: work inprogress --- source/slang/check.cpp | 53 +++++++++++++++--------- source/slang/lookup.cpp | 21 ++++++++++ source/slang/parser.cpp | 25 +++++++++++- source/slang/syntax.cpp | 36 ++-------------- tests/compute/assoctype-complex.slang | 59 +++++++++++++++++++++++++++ tests/compute/assoctype-complex.slang._ignore | 44 -------------------- tests/compute/generics-constraint1.slang | 17 ++++++++ 7 files changed, 157 insertions(+), 98 deletions(-) create mode 100644 tests/compute/assoctype-complex.slang delete mode 100644 tests/compute/assoctype-complex.slang._ignore create mode 100644 tests/compute/generics-constraint1.slang diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 2f4ed2ce8..dc4e48545 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -169,23 +169,18 @@ namespace Slang sexpr->declRef = declRef; expr = sexpr; } - if (auto assocTypeDeclRef = declRef.As()) + if (auto constraintType = expr->type->As()) { - if (auto genConstraintType = baseExpr->type->As()) - { - // if this is a reference from a generic parameter, we need to generate a AssocTypeDeclRefType type. - // for example, if we have an expression T.U where T:ISimple, and U is an associated type defined in ISimple. - // then this expression should evaluate to AssocTypeDeclRefType(T, U). - auto assocTypeDeclType = new AssocTypeDeclRefType(); - assocTypeDeclType->declRef = assocTypeDeclRef; - assocTypeDeclType->sourceType = genConstraintType->subType; - assocTypeDeclType->setSession(getSession()); - expr->type = QualType(getTypeType(assocTypeDeclType)); - } + if (baseExpr->type->As()) + constraintType->subType = baseExpr->type->As()->type; + else + constraintType->subType = baseExpr->type; + } - else if (auto funcDeclRef = declRef.As()) + + if (auto genConstraintType = baseExpr->type->As()) { - if (auto genConstraintType = baseExpr->type->As()) + if (auto funcDeclRef = declRef.As()) { // if this is call expression, propagate the source associated type to the result type auto funcType = expr->type->As(); @@ -201,9 +196,24 @@ namespace Slang newFuncType->setSession(funcType->getSession()); expr->type = QualType(newFuncType); } - + } + else if (auto assocTypeDeclRef = declRef.As()) + { + auto assocTypeDeclType = new AssocTypeDeclRefType(); + assocTypeDeclType->declRef = assocTypeDeclRef; + assocTypeDeclType->sourceType = genConstraintType->subType; + assocTypeDeclType->setSession(getSession()); + expr->type = QualType(getTypeType(assocTypeDeclType)); } } + else if (auto assocTypeDeclRef = declRef.As()) + { + auto assocTypeDeclType = new AssocTypeDeclRefType(); + assocTypeDeclType->declRef = assocTypeDeclRef; + assocTypeDeclType->sourceType = baseExpr->type; + assocTypeDeclType->setSession(getSession()); + expr->type = QualType(getTypeType(assocTypeDeclType)); + } return expr; } else @@ -1262,6 +1272,14 @@ namespace Slang checkDecl(genericDecl->inner); } + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl * genericConstraintDecl) + { + // check the type being inherited from + auto base = genericConstraintDecl->sup; + base = TranslateTypeNode(base); + genericConstraintDecl->sup = base; + } + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) { // check the type being inherited from @@ -1325,11 +1343,6 @@ namespace Slang // These are only used in the stdlib, so no checking is needed for now } - void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*) - { - // These are only used in the stdlib, so no checking is needed for now - } - void visitModifier(Modifier*) { // Do nothing with modifiers for now diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp index c0cb657c4..b97dac560 100644 --- a/source/slang/lookup.cpp +++ b/source/slang/lookup.cpp @@ -453,6 +453,27 @@ void lookUpMemberImpl( } } } + else if (auto assocTypeDeclRefType = type->As()) + { + auto assocTypeDeclRef = assocTypeDeclRefType->declRef; + for (auto constraintDeclRef : getMembersOfType(assocTypeDeclRef)) + { + // The super-type in the constraint (e.g., `Foo` in `T : Foo`) + // will tell us a type we should use for lookup. + auto bound = GetSup(constraintDeclRef); + + // Go ahead and use the target type, with an appropriate breadcrumb + // to indicate that we indirected through a type constraint. + + BreadcrumbInfo breadcrumb; + breadcrumb.prev = inBreadcrumbs; + breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint; + breadcrumb.declRef = constraintDeclRef; + + // TODO: Need to consider case where this might recurse infinitely. + lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb); + } + } } LookupResult lookUpMember( diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 224450a66..bf85356db 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -2113,7 +2113,30 @@ namespace Slang auto nameToken = parser->ReadToken(TokenType::Identifier); assocTypeDecl->nameAndLoc = NameLoc(nameToken); assocTypeDecl->loc = nameToken.loc; - parseOptionalInheritanceClause(parser, assocTypeDecl.Ptr()); + if (AdvanceIf(parser, TokenType::Colon)) + { + while (!parser->tokenReader.IsAtEnd()) + { + auto paramConstraint = new GenericTypeConstraintDecl(); + parser->FillPosition(paramConstraint); + + auto paramType = DeclRefType::Create( + parser->getSession(), + DeclRef(assocTypeDecl, nullptr)); + + auto paramTypeExpr = new SharedTypeExpr(); + paramTypeExpr->loc = assocTypeDecl->loc; + paramTypeExpr->base.type = paramType; + paramTypeExpr->type = QualType(getTypeType(paramType)); + + paramConstraint->sub = TypeExp(paramTypeExpr); + paramConstraint->sup = parser->ParseTypeExp(); + + AddMember(assocTypeDecl, paramConstraint); + if (!AdvanceIf(parser, TokenType::Comma)) + break; + } + } parser->ReadToken(TokenType::Semicolon); return assocTypeDecl; } diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index fd7fc0e14..3e38955ba 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -961,7 +961,8 @@ void Type::accept(IValVisitor* visitor, void* extra) { if (!sourceType) return this; - if (auto parentDeclRefType = sourceType->As()) + auto substSourceType = sourceType->SubstituteImpl(subst, ioDiff); + if (auto parentDeclRefType = substSourceType.As()) { auto parentDeclRef = parentDeclRefType->declRef; DeclRef newParentDeclRef = parentDeclRef.As(); @@ -1045,38 +1046,7 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr GenericConstraintDeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff) { - auto genParamDecl = subType.As()->declRef.As(); - // search for a substitution that might apply to us - for (auto s = subst; s; s = s->outer.Ptr()) - { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = s->genericDecl; - if (genericDecl != genParamDecl.getDecl()->ParentDecl) - continue; - int index = 0; - for (auto m : genericDecl->Members) - { - if (m.Ptr() == genParamDecl.getDecl()) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return s->args[index]; - } - else if (auto typeParam = m.As()) - { - index++; - } - else if (auto valParam = m.As()) - { - index++; - } - else - { - } - } - } - return this; + return subType->SubstituteImpl(subst, ioDiff); } int GenericConstraintDeclRefType::GetHashCode() diff --git a/tests/compute/assoctype-complex.slang b/tests/compute/assoctype-complex.slang new file mode 100644 index 000000000..de3f1a103 --- /dev/null +++ b/tests/compute/assoctype-complex.slang @@ -0,0 +1,59 @@ +//TEST(smoke, compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer outputBuffer; +interface IBase +{ + associatedtype V; + V sub(V a0, V a1); +} +interface ISimple +{ + associatedtype U : IBase; + U.V add(U v0, U v1); +} + +struct Val : IBase +{ + typedef int V; + V sub(V a0, V a1) + { + return a0-a1; + } +}; + +struct Simple : ISimple +{ + typedef Val U; + Val.V add(U v0, U v1) + { + return v0.sub(4, v1.sub(1,2)); + } +}; +/* +__generic +T.U.V test(T simple, T.U v0, T.U v1) +{ + return simple.add(v0, v1); +} + +__generic +T test(T v0, T v1) +{ + return v0 + v1; +} +*/ +__generic +T test(T v0, T v1) +{ + return T(3.0); +} +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + //Simple s; + //Val v0, v1; + //float outVal = test(s, v0, v1); // == 1.0 + float outVal = test(1.0, 2.0); + outputBuffer[dispatchThreadID.x] = outVal; +} \ No newline at end of file diff --git a/tests/compute/assoctype-complex.slang._ignore b/tests/compute/assoctype-complex.slang._ignore deleted file mode 100644 index 3e590b2e0..000000000 --- a/tests/compute/assoctype-complex.slang._ignore +++ /dev/null @@ -1,44 +0,0 @@ -RWStructuredBuffer outputBuffer; -interface IBase -{ - associatedtype V; - V sub(V a0, V a1); -} -interface ISimple -{ - associatedtype U : IBase; - U.V add(U v0, U v1); -} - -struct Val : IBase -{ - typedef int V; - V sub(V a0, V a1) - { - return a0-a1; - } -}; - -struct Simple : ISimple -{ - typedef Val U; - Val.V add(U v0, U v1) - { - return v0.sub(4, v1.sub(1,2)); - } -}; - -__generic -T.U.V test(T simple, T.U v0, T.U v1) -{ - return simple.add(v0, v1); -} - -[numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) -{ - Simple s; - Val v0, v1; - float outVal = test(s, v0, v1); // == 1.0 - outputBuffer[0] = outVal; -} \ No newline at end of file diff --git a/tests/compute/generics-constraint1.slang b/tests/compute/generics-constraint1.slang new file mode 100644 index 000000000..ff90c1cc9 --- /dev/null +++ b/tests/compute/generics-constraint1.slang @@ -0,0 +1,17 @@ +//TEST(smoke, compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer outputBuffer; + +__generic +T test(T v0, T v1) +{ + return T(3.0); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float outVal = test(1.0, 2.0); + outputBuffer[dispatchThreadID.x] = outVal; +} \ No newline at end of file -- cgit v1.2.3