summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-11-29 20:01:41 -0500
committerGitHub <noreply@github.com>2022-11-29 17:01:41 -0800
commitf5581786a1891cedb165adb1afe71fe34f26e030 (patch)
tree86da2f1acbaec920ac0c38349897b293b405c021 /source/slang/slang-ir-autodiff.cpp
parentaf7f40063dfed1c651d33b93956c7623a7d2c050 (diff)
Refactored reverse-mode implementation to use 4 separate passes. (#2531)
* Added partial implementation for reverse-mode * Fixing several compile and runtime errors. * Fixed several issues with reverse-mode passes. * Fixed more issues. Basic reverse-mode tests passing Co-authored-by: Edward Liu <shiqiu1105@gmail.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp54
1 files changed, 54 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 313760d85..b0dbf62fa 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -327,6 +327,60 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde
}
+
+void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
+{
+ parentFunc = func;
+
+ auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+ SLANG_RELEASE_ASSERT(decor);
+
+ // Build lookup dictionary for type witnesses.
+ for (auto child = decor->getFirstChild(); child; child = child->next)
+ {
+ if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
+ {
+ auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+ if (existingItem)
+ {
+ *existingItem = item->getWitness();
+ }
+ else
+ {
+ differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
+ }
+ }
+ }
+}
+
+IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
+{
+ IRInst* foundResult = nullptr;
+ differentiableWitnessDictionary.TryGetValue(type, foundResult);
+ return foundResult;
+}
+
+IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+{
+ if (auto conformance = lookUpConformanceForType(origType))
+ {
+ return _lookupWitness(builder, conformance, key);
+ }
+ return nullptr;
+}
+
+void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
+{
+ for (auto globalInst : sharedContext->moduleInst->getChildren())
+ {
+ if (auto pairType = as<IRDifferentialPairType>(globalInst))
+ {
+ differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness());
+ }
+ }
+}
+
+
void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
for (auto inst : parent->getChildren())