diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 09:39:08 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 09:39:08 -0800 |
| commit | 97cb4851eed7a43f10196971b08d3d311386ce9f (patch) | |
| tree | 99ba81368068b3345fa23b749108265aa753ed2b /source/slang/slang-ast-val.cpp | |
| parent | 6178cb601368e977c4aa82e0ae25b8eb1e875d84 (diff) | |
Autodiff through simple dynamic dispatch. (#2527)
* Autodiff through simple dynamic dispatch.
* Revert changes.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 87e89ef18..a0f0552c6 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -1516,4 +1516,44 @@ Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witnes return witnessResult; } + +bool DifferentiateVal::_equalsValOverride(Val* val) +{ + if (auto other = as<DifferentiateVal>(val)) + { + return other->astNodeType == astNodeType && other->func == func; + } + return false; +} + +void DifferentiateVal::_toTextOverride(StringBuilder& out) +{ + out << "DifferentiateVal("; + out << func; + out << ")"; +} + +HashCode DifferentiateVal::_getHashCodeOverride() +{ + HashCode result = (HashCode)astNodeType; + result = combineHash(result, func.getHashCode()); + return result; +} + +Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto newFunc = func.substituteImpl(astBuilder, subst, &diff); + *ioDiff += diff; + if (diff) + { + auto result = as<DifferentiateVal>(astBuilder->createByNodeType(astNodeType)); + result->func = newFunc; + return result; + } + // Nothing found: don't substitute. + return this; +} + + } // namespace Slang |
