From 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 20 Oct 2022 14:22:00 -0400 Subject: Modified the new type system to support generic differentiable types … (#2413) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He --- source/slang/slang-check-conformance.cpp | 139 ++++++++++++++++++++++++++++--- 1 file changed, 126 insertions(+), 13 deletions(-) (limited to 'source/slang/slang-check-conformance.cpp') diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index e0c1f3702..cf362dcdd 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -18,6 +18,62 @@ namespace Slang return witness; } + + Val* simplifyWitness(ASTBuilder* builder, Val* witness) + { + if (auto extractFromConjunction = as(witness)) + { + auto simplWitness = simplifyWitness(builder, extractFromConjunction->conjunctionWitness); + if (auto conjunction = as(simplWitness)) + { + auto index = extractFromConjunction->indexInConjunction; + SLANG_ASSERT(index == 0 || index == 1); + if (index == 0) + return conjunction->leftWitness; + else + return conjunction->rightWitness; + } + + ExtractFromConjunctionSubtypeWitness* simplExtractFromConjunction = builder->create(); + simplExtractFromConjunction->sub = extractFromConjunction->sub; + simplExtractFromConjunction->sup = extractFromConjunction->sup; + simplExtractFromConjunction->indexInConjunction = extractFromConjunction->indexInConjunction; + simplExtractFromConjunction->conjunctionWitness = as(simplWitness); + + return simplExtractFromConjunction; + } + else if (auto conjunctionWitness = as(witness)) + { + auto simplConjunctionWitness = builder->create(); + simplConjunctionWitness->leftWitness = as(simplifyWitness(builder, conjunctionWitness->leftWitness)); + simplConjunctionWitness->rightWitness = as(simplifyWitness(builder, conjunctionWitness->rightWitness)); + simplConjunctionWitness->sub = conjunctionWitness->sub; + simplConjunctionWitness->sup = conjunctionWitness->sup; + + return simplConjunctionWitness; + } + else if (auto transitiveWitness = as(witness)) + { + TransitiveSubtypeWitness* simplTransitiveWitness = builder->getOrCreateWithDefaultCtor( + transitiveWitness->sub, + transitiveWitness->sup, + transitiveWitness->midToSup); + + simplTransitiveWitness->sub = transitiveWitness->sub; + simplTransitiveWitness->sup = transitiveWitness->sup; + simplTransitiveWitness->midToSup = as(simplifyWitness(builder, transitiveWitness->midToSup)); + simplTransitiveWitness->subToMid = as(simplifyWitness(builder, transitiveWitness->subToMid)); + + return simplTransitiveWitness; + } + else + { + // TODO: Add other cases. + return witness; + } + } + + Val* SemanticsVisitor::createTypeWitness( Type* subType, DeclRef superTypeDeclRef, @@ -70,7 +126,7 @@ namespace Slang // As long as there is more than one breadcrumb, we // need to be creating transitive witnesses. - while(bb->prev) + while (bb->prev) { // On the first iteration when processing the list // above, the breadcrumb would be for `{ C : D }`, @@ -83,19 +139,42 @@ namespace Slang // where `[...]` represents the "hole" we leave // open to fill in next. // - DeclaredSubtypeWitness* declaredWitness = - m_astBuilder->getOrCreate( - bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions); + if (bb->flavor == TypeWitnessBreadcrumb::Flavor::DeclFlavor) + { + DeclaredSubtypeWitness* declaredWitness = + m_astBuilder->getOrCreate( + bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions); + + TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor(subType, bb->sup, declaredWitness); + transitiveWitness->sub = subType; + transitiveWitness->sup = bb->sup; + transitiveWitness->midToSup = declaredWitness; + + // Fill in the current hole, and then set the + // hole to point into the node we just created. + *link = transitiveWitness; + link = &transitiveWitness->subToMid; + } + else if(bb->flavor == TypeWitnessBreadcrumb::Flavor::AndTypeLeftFlavor) + { + ExtractFromConjunctionSubtypeWitness* extractWitness = m_astBuilder->create(); + extractWitness->sub = subType; + extractWitness->sup = bb->sup; + extractWitness->indexInConjunction = 0; - TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor(subType, bb->sup, declaredWitness); - transitiveWitness->sub = subType; - transitiveWitness->sup = bb->sup; - transitiveWitness->midToSup = declaredWitness; + *link = extractWitness; + link = (SubtypeWitness**) &extractWitness->conjunctionWitness; + } + else if(bb->flavor == TypeWitnessBreadcrumb::Flavor::AndTypeRightFlavor) + { + ExtractFromConjunctionSubtypeWitness* extractWitness = m_astBuilder->create(); + extractWitness->sub = subType; + extractWitness->sup = bb->sup; + extractWitness->indexInConjunction = 1; - // Fill in the current hole, and then set the - // hole to point into the node we just created. - *link = transitiveWitness; - link = &transitiveWitness->subToMid; + *link = extractWitness; + link = (SubtypeWitness**) &extractWitness->conjunctionWitness; + } // Move on with the list. bb = bb->prev; @@ -108,9 +187,14 @@ namespace Slang DeclaredSubtypeWitness* declaredWitness = createSimpleSubtypeWitness(bb); *link = declaredWitness; + // Simplify witnesses of the form ExtractFromConjunction(ConjunctionWitness(...)) + // TODO: At some point, we need a more robust way of checking that two witnesses are in-fact 'equal'. + // In the meantime, this step should suffice. + + // We now know that our original `witness` variable has been // filled in, and there are no other holes. - return witness; + return simplifyWitness(m_astBuilder, witness); } bool SemanticsVisitor::isInterfaceSafeForTaggedUnion( @@ -379,6 +463,35 @@ namespace Slang } return true; } + else if (auto andType = as(subType)) + { + // (L & R) is a subtype of T if either L or R is a subtype of T. + // Note that in this method T is explicitly a DeclRef and so cannot be a conjunction itself. + // + TypeWitnessBreadcrumb leftBreadcrumb; + leftBreadcrumb.prev = inBreadcrumbs; + leftBreadcrumb.sub = andType; + leftBreadcrumb.sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); + leftBreadcrumb.declRef = makeDeclRef((Decl*)nullptr); + leftBreadcrumb.flavor = TypeWitnessBreadcrumb::Flavor::AndTypeLeftFlavor; + + if(_isDeclaredSubtype(originalSubType, andType->left, superTypeDeclRef, outWitness, &leftBreadcrumb)) + { + return true; + } + + TypeWitnessBreadcrumb rightBreadcrumb; + rightBreadcrumb.prev = inBreadcrumbs; + rightBreadcrumb.sub = andType; + rightBreadcrumb.sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); + rightBreadcrumb.declRef = makeDeclRef((Decl*)nullptr); + rightBreadcrumb.flavor = TypeWitnessBreadcrumb::Flavor::AndTypeRightFlavor; + + if(_isDeclaredSubtype(originalSubType, andType->right, superTypeDeclRef, outWitness, &rightBreadcrumb)) + { + return true; + } + } // default is failure return false; } -- cgit v1.2.3