diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-11-29 20:01:41 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-29 17:01:41 -0800 |
| commit | f5581786a1891cedb165adb1afe71fe34f26e030 (patch) | |
| tree | 86da2f1acbaec920ac0c38349897b293b405c021 /source/slang/slang-ir-autodiff.cpp | |
| parent | af7f40063dfed1c651d33b93956c7623a7d2c050 (diff) | |
Refactored reverse-mode implementation to use 4 separate passes. (#2531)
* Added partial implementation for reverse-mode
* Fixing several compile and runtime errors.
* Fixed several issues with reverse-mode passes.
* Fixed more issues. Basic reverse-mode tests passing
Co-authored-by: Edward Liu <shiqiu1105@gmail.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 313760d85..b0dbf62fa 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -327,6 +327,60 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde } + +void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) +{ + parentFunc = func; + + auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); + SLANG_RELEASE_ASSERT(decor); + + // Build lookup dictionary for type witnesses. + for (auto child = decor->getFirstChild(); child; child = child->next) + { + if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) + { + auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); + if (existingItem) + { + *existingItem = item->getWitness(); + } + else + { + differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); + } + } + } +} + +IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) +{ + IRInst* foundResult = nullptr; + differentiableWitnessDictionary.TryGetValue(type, foundResult); + return foundResult; +} + +IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) +{ + if (auto conformance = lookUpConformanceForType(origType)) + { + return _lookupWitness(builder, conformance, key); + } + return nullptr; +} + +void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() +{ + for (auto globalInst : sharedContext->moduleInst->getChildren()) + { + if (auto pairType = as<IRDifferentialPairType>(globalInst)) + { + differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness()); + } + } +} + + void stripAutoDiffDecorationsFromChildren(IRInst* parent) { for (auto inst : parent->getChildren()) |
