summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-val.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-ast-val.cpp
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
-rw-r--r--source/slang/slang-ast-val.cpp85
1 files changed, 61 insertions, 24 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 377dee350..a8ceaa716 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -283,7 +283,7 @@ Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, Sub
{
if (constraintParam == declRef.getDecl())
{
- found = true;
+ found = true;
break;
}
index++;
@@ -443,6 +443,66 @@ HashCode TransitiveSubtypeWitness::_getHashCodeOverride()
return hash;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractFromConjunctionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
+{
+ int diff = 0;
+
+ Type* substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
+ Type* substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
+ SubtypeWitness* substWitness = as<SubtypeWitness>(conjunctionWitness->substituteImpl(astBuilder, subst, &diff));
+
+ // If nothing changed, then we can bail out early.
+ if (!diff)
+ return this;
+
+ // Something changes, so let the caller know.
+ (*ioDiff)++;
+
+ // If the substituted witness is a conjunction, break it apart, but it's important to replace the
+ // sub and super types with the current ones since the conjunction witness will have an
+ //
+ if (auto substConjunctionWitness = as<ConjunctionSubtypeWitness>(substWitness))
+ {
+ if (indexInConjunction == 0)
+ {
+ auto witness = as<SubtypeWitness>(substConjunctionWitness->leftWitness);
+ SLANG_ASSERT(witness);
+
+ witness->sub = substSub;
+ witness->sup = substSup;
+
+ return witness;
+ }
+ else if (indexInConjunction == 1)
+ {
+ auto witness = as<SubtypeWitness>(substConjunctionWitness->rightWitness);
+ SLANG_ASSERT(witness);
+
+ witness->sub = substSub;
+ witness->sup = substSup;
+
+ return witness;
+ }
+ else
+ {
+ SLANG_UNIMPLEMENTED_X("conjunction index must be 0 or 1");
+ }
+ }
+ else
+ {
+ // In the simple case, we just construct a new conjunction subtype
+ // witness.
+ ExtractFromConjunctionSubtypeWitness* result = astBuilder->create<ExtractFromConjunctionSubtypeWitness>();
+ result->sub = substSub;
+ result->sup = substSup;
+ result->conjunctionWitness = substWitness;
+ result->indexInConjunction = indexInConjunction;
+ return result;
+ }
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val)
@@ -637,29 +697,6 @@ HashCode ExtractFromConjunctionSubtypeWitness::_getHashCodeOverride()
return combineHash(indexInConjunction, conjunctionWitness ? conjunctionWitness->getHashCode() : 0);
}
-Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
-{
- int diff = 0;
- Val* newConjunctionWitness = nullptr;
-
- auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
- auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
-
- if (this->conjunctionWitness)
- newConjunctionWitness = conjunctionWitness->substituteImpl(astBuilder, subst, &diff);
- *ioDiff += diff;
-
- if (diff)
- {
- auto result = astBuilder->create<ExtractFromConjunctionSubtypeWitness>();
- result->conjunctionWitness = newConjunctionWitness;
- result->sub = substSub;
- result->sup = substSup;
- return result;
- }
- return this;
-}
-
// ModifierVal
bool ModifierVal::_equalsValOverride(Val* val)