summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-call.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-call.cpp')
-rw-r--r--source/slang/slang-ir-diff-call.cpp43
1 files changed, 34 insertions, 9 deletions
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index 92044be3c..ee78246fe 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -52,25 +52,50 @@ struct DerivativeCallProcessContext
// the intstructions.
void processDifferentiate(IRJVPDifferentiate* derivOfInst)
{
- IRFunc* jvpFunc = nullptr;
+ IRInst* jvpCallable = nullptr;
+
+ // First get base function
+ auto origCallable = derivOfInst->getBaseFn();
+
+ IRSpecialize* specialization = nullptr;
+
+ // If the base is a specialize inst, get the inner fn.
+ if (auto origSpecialize = as<IRSpecialize>(origCallable))
+ {
+ specialization = origSpecialize;
+ origCallable = origSpecialize->getBase();
+ }
+
+ // We should have either a generic or a function reference on our hands.
+ SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable));
// Resolve the derivative function.
//
// Check for the 'JVPDerivativeReference' decorator on the
// base function.
- if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>())
{
- jvpFunc = jvpRefDecorator->getJVPFunc();
+ jvpCallable = jvpRefDecorator->getJVPFunc();
+ }
+
+ SLANG_ASSERT(jvpCallable);
+
+ if (specialization)
+ {
+ // Replace the specialization target with the JVP func.
+ specialization->setOperand(0, jvpCallable);
+
+ // Then replace the JVPDifferentiate inst with the specialization.
+ derivOfInst->replaceUsesWith(specialization);
}
-
- // Substitute all uses of the 'derivativeOf' operation
- // with the resolved derivative function.
- while (auto use = derivOfInst->firstUse)
+ else
{
- use->set(jvpFunc);
+ // Substitute all uses of the 'derivativeOf' operation
+ // with the resolved derivative function.
+ derivOfInst->replaceUsesWith(jvpCallable);
}
- // Remove the 'derivativeOf'
+ // Remove the 'derivativeOf' inst.
derivOfInst->removeAndDeallocate();
}
};