diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-20 14:22:00 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-20 11:22:00 -0700 |
| commit | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch) | |
| tree | e85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-ast-val.cpp | |
| parent | 576c8407e60143682cd40c68101c6eae8563ca3d (diff) | |
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions.
* Changed a few asserts to release asserts to avoid unreferenced variable errors
* Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl
* Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types.
* Moved the auto-diff passes to after the specialization step, added a more complex generics test
* Added a generics stress test and fixed AST-side logic. IR side needs some more work
* Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions
* Changed differential getters to use pointer types, added getter type checking
* Fixed some bugs related to diff type registration and differential getters
* Removed some superfluous code
* Removed some more unused code.
* Fixed an issue with witness substitution
* Minor fix
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 85 |
1 files changed, 61 insertions, 24 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 377dee350..a8ceaa716 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -283,7 +283,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub { if (constraintParam == declRef.getDecl()) { - found = true; + found = true; break; } index++; @@ -443,6 +443,66 @@ HashCode TransitiveSubtypeWitness::_getHashCodeOverride() return hash; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +{ + int diff = 0; + + Type* substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); + Type* substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); + SubtypeWitness* substWitness = as<SubtypeWitness>(conjunctionWitness->substituteImpl(astBuilder, subst, &diff)); + + // If nothing changed, then we can bail out early. + if (!diff) + return this; + + // Something changes, so let the caller know. + (*ioDiff)++; + + // If the substituted witness is a conjunction, break it apart, but it's important to replace the + // sub and super types with the current ones since the conjunction witness will have an + // + if (auto substConjunctionWitness = as<ConjunctionSubtypeWitness>(substWitness)) + { + if (indexInConjunction == 0) + { + auto witness = as<SubtypeWitness>(substConjunctionWitness->leftWitness); + SLANG_ASSERT(witness); + + witness->sub = substSub; + witness->sup = substSup; + + return witness; + } + else if (indexInConjunction == 1) + { + auto witness = as<SubtypeWitness>(substConjunctionWitness->rightWitness); + SLANG_ASSERT(witness); + + witness->sub = substSub; + witness->sup = substSup; + + return witness; + } + else + { + SLANG_UNIMPLEMENTED_X("conjunction index must be 0 or 1"); + } + } + else + { + // In the simple case, we just construct a new conjunction subtype + // witness. + ExtractFromConjunctionSubtypeWitness* result = astBuilder->create<ExtractFromConjunctionSubtypeWitness>(); + result->sub = substSub; + result->sup = substSup; + result->conjunctionWitness = substWitness; + result->indexInConjunction = indexInConjunction; + return result; + } +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val) @@ -637,29 +697,6 @@ HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride() return combineHash(indexInConjunction, conjunctionWitness ? conjunctionWitness->getHashCode() : 0); } -Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - Val* newConjunctionWitness = nullptr; - - auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); - auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); - - if (this->conjunctionWitness) - newConjunctionWitness = conjunctionWitness->substituteImpl(astBuilder, subst, &diff); - *ioDiff += diff; - - if (diff) - { - auto result = astBuilder->create<ExtractFromConjunctionSubtypeWitness>(); - result->conjunctionWitness = newConjunctionWitness; - result->sub = substSub; - result->sup = substSup; - return result; - } - return this; -} - // ModifierVal bool ModifierVal::_equalsValOverride(Val* val) |
