summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-call.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-ir-diff-call.cpp
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-diff-call.cpp')
-rw-r--r--source/slang/slang-ir-diff-call.cpp43
1 files changed, 34 insertions, 9 deletions
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index 92044be3c..ee78246fe 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -52,25 +52,50 @@ struct DerivativeCallProcessContext
// the intstructions.
void processDifferentiate(IRJVPDifferentiate* derivOfInst)
{
- IRFunc* jvpFunc = nullptr;
+ IRInst* jvpCallable = nullptr;
+
+ // First get base function
+ auto origCallable = derivOfInst->getBaseFn();
+
+ IRSpecialize* specialization = nullptr;
+
+ // If the base is a specialize inst, get the inner fn.
+ if (auto origSpecialize = as<IRSpecialize>(origCallable))
+ {
+ specialization = origSpecialize;
+ origCallable = origSpecialize->getBase();
+ }
+
+ // We should have either a generic or a function reference on our hands.
+ SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable));
// Resolve the derivative function.
//
// Check for the 'JVPDerivativeReference' decorator on the
// base function.
- if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>())
{
- jvpFunc = jvpRefDecorator->getJVPFunc();
+ jvpCallable = jvpRefDecorator->getJVPFunc();
+ }
+
+ SLANG_ASSERT(jvpCallable);
+
+ if (specialization)
+ {
+ // Replace the specialization target with the JVP func.
+ specialization->setOperand(0, jvpCallable);
+
+ // Then replace the JVPDifferentiate inst with the specialization.
+ derivOfInst->replaceUsesWith(specialization);
}
-
- // Substitute all uses of the 'derivativeOf' operation
- // with the resolved derivative function.
- while (auto use = derivOfInst->firstUse)
+ else
{
- use->set(jvpFunc);
+ // Substitute all uses of the 'derivativeOf' operation
+ // with the resolved derivative function.
+ derivOfInst->replaceUsesWith(jvpCallable);
}
- // Remove the 'derivativeOf'
+ // Remove the 'derivativeOf' inst.
derivOfInst->removeAndDeallocate();
}
};