diff options
Diffstat (limited to 'source/slang/slang-ir-diff-call.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-call.cpp | 43 |
1 files changed, 34 insertions, 9 deletions
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index 92044be3c..ee78246fe 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -52,25 +52,50 @@ struct DerivativeCallProcessContext // the intstructions. void processDifferentiate(IRJVPDifferentiate* derivOfInst) { - IRFunc* jvpFunc = nullptr; + IRInst* jvpCallable = nullptr; + + // First get base function + auto origCallable = derivOfInst->getBaseFn(); + + IRSpecialize* specialization = nullptr; + + // If the base is a specialize inst, get the inner fn. + if (auto origSpecialize = as<IRSpecialize>(origCallable)) + { + specialization = origSpecialize; + origCallable = origSpecialize->getBase(); + } + + // We should have either a generic or a function reference on our hands. + SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable)); // Resolve the derivative function. // // Check for the 'JVPDerivativeReference' decorator on the // base function. - if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>()) { - jvpFunc = jvpRefDecorator->getJVPFunc(); + jvpCallable = jvpRefDecorator->getJVPFunc(); + } + + SLANG_ASSERT(jvpCallable); + + if (specialization) + { + // Replace the specialization target with the JVP func. + specialization->setOperand(0, jvpCallable); + + // Then replace the JVPDifferentiate inst with the specialization. + derivOfInst->replaceUsesWith(specialization); } - - // Substitute all uses of the 'derivativeOf' operation - // with the resolved derivative function. - while (auto use = derivOfInst->firstUse) + else { - use->set(jvpFunc); + // Substitute all uses of the 'derivativeOf' operation + // with the resolved derivative function. + derivOfInst->replaceUsesWith(jvpCallable); } - // Remove the 'derivativeOf' + // Remove the 'derivativeOf' inst. derivOfInst->removeAndDeallocate(); } }; |
