summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-call.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-26 22:21:29 -0400
committerGitHub <noreply@github.com>2022-10-26 19:21:29 -0700
commitf7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch)
tree574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/slang-ir-diff-call.cpp
parent939be44ca23476e622dfb24a592383fe2a1da61f (diff)
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/slang-ir-diff-call.cpp')
-rw-r--r--source/slang/slang-ir-diff-call.cpp35
1 files changed, 12 insertions, 23 deletions
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index ee78246fe..3f2c6c789 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -57,22 +57,22 @@ struct DerivativeCallProcessContext
// First get base function
auto origCallable = derivOfInst->getBaseFn();
- IRSpecialize* specialization = nullptr;
-
- // If the base is a specialize inst, get the inner fn.
+ // Resolve the derivative function for IRJVPDifferentiate(IRSpecialize(IRFunc))
+ // Check the specialize inst for a reference to the derivative fn.
+ //
if (auto origSpecialize = as<IRSpecialize>(origCallable))
{
- specialization = origSpecialize;
- origCallable = origSpecialize->getBase();
+ if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ jvpCallable = jvpSpecRefDecorator->getJVPFunc();
+ }
}
- // We should have either a generic or a function reference on our hands.
- SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable));
-
- // Resolve the derivative function.
+ // Resolve the derivative function for an IRJVPDifferentiate(IRFunc)
//
// Check for the 'JVPDerivativeReference' decorator on the
// base function.
+ //
if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>())
{
jvpCallable = jvpRefDecorator->getJVPFunc();
@@ -80,20 +80,9 @@ struct DerivativeCallProcessContext
SLANG_ASSERT(jvpCallable);
- if (specialization)
- {
- // Replace the specialization target with the JVP func.
- specialization->setOperand(0, jvpCallable);
-
- // Then replace the JVPDifferentiate inst with the specialization.
- derivOfInst->replaceUsesWith(specialization);
- }
- else
- {
- // Substitute all uses of the 'derivativeOf' operation
- // with the resolved derivative function.
- derivOfInst->replaceUsesWith(jvpCallable);
- }
+ // Substitute all uses of the 'derivativeOf' operation
+ // with the resolved derivative function.
+ derivOfInst->replaceUsesWith(jvpCallable);
// Remove the 'derivativeOf' inst.
derivOfInst->removeAndDeallocate();