From 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 20 Oct 2022 14:22:00 -0400 Subject: Modified the new type system to support generic differentiable types … (#2413) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He --- source/slang/slang-ast-val.cpp | 85 ++++++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 24 deletions(-) (limited to 'source/slang/slang-ast-val.cpp') diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 377dee350..a8ceaa716 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -283,7 +283,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub { if (constraintParam == declRef.getDecl()) { - found = true; + found = true; break; } index++; @@ -443,6 +443,66 @@ HashCode TransitiveSubtypeWitness::_getHashCodeOverride() return hash; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +{ + int diff = 0; + + Type* substSub = as(sub->substituteImpl(astBuilder, subst, &diff)); + Type* substSup = as(sup->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substWitness = as(conjunctionWitness->substituteImpl(astBuilder, subst, &diff)); + + // If nothing changed, then we can bail out early. + if (!diff) + return this; + + // Something changes, so let the caller know. + (*ioDiff)++; + + // If the substituted witness is a conjunction, break it apart, but it's important to replace the + // sub and super types with the current ones since the conjunction witness will have an + // + if (auto substConjunctionWitness = as(substWitness)) + { + if (indexInConjunction == 0) + { + auto witness = as(substConjunctionWitness->leftWitness); + SLANG_ASSERT(witness); + + witness->sub = substSub; + witness->sup = substSup; + + return witness; + } + else if (indexInConjunction == 1) + { + auto witness = as(substConjunctionWitness->rightWitness); + SLANG_ASSERT(witness); + + witness->sub = substSub; + witness->sup = substSup; + + return witness; + } + else + { + SLANG_UNIMPLEMENTED_X("conjunction index must be 0 or 1"); + } + } + else + { + // In the simple case, we just construct a new conjunction subtype + // witness. + ExtractFromConjunctionSubtypeWitness* result = astBuilder->create(); + result->sub = substSub; + result->sup = substSup; + result->conjunctionWitness = substWitness; + result->indexInConjunction = indexInConjunction; + return result; + } +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val) @@ -637,29 +697,6 @@ HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride() return combineHash(indexInConjunction, conjunctionWitness ? conjunctionWitness->getHashCode() : 0); } -Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - Val* newConjunctionWitness = nullptr; - - auto substSub = as(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as(sup->substituteImpl(astBuilder, subst, &diff)); - - if (this->conjunctionWitness) - newConjunctionWitness = conjunctionWitness->substituteImpl(astBuilder, subst, &diff); - *ioDiff += diff; - - if (diff) - { - auto result = astBuilder->create(); - result->conjunctionWitness = newConjunctionWitness; - result->sub = substSub; - result->sup = substSup; - return result; - } - return this; -} - // ModifierVal bool ModifierVal::_equalsValOverride(Val* val) -- cgit v1.2.3