diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-26 22:21:29 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-26 19:21:29 -0700 |
| commit | f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch) | |
| tree | 574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/slang-ir-diff-call.cpp | |
| parent | 939be44ca23476e622dfb24a592383fe2a1da61f (diff) | |
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/slang-ir-diff-call.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-call.cpp | 35 |
1 files changed, 12 insertions, 23 deletions
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index ee78246fe..3f2c6c789 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -57,22 +57,22 @@ struct DerivativeCallProcessContext // First get base function auto origCallable = derivOfInst->getBaseFn(); - IRSpecialize* specialization = nullptr; - - // If the base is a specialize inst, get the inner fn. + // Resolve the derivative function for IRJVPDifferentiate(IRSpecialize(IRFunc)) + // Check the specialize inst for a reference to the derivative fn. + // if (auto origSpecialize = as<IRSpecialize>(origCallable)) { - specialization = origSpecialize; - origCallable = origSpecialize->getBase(); + if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + jvpCallable = jvpSpecRefDecorator->getJVPFunc(); + } } - // We should have either a generic or a function reference on our hands. - SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable)); - - // Resolve the derivative function. + // Resolve the derivative function for an IRJVPDifferentiate(IRFunc) // // Check for the 'JVPDerivativeReference' decorator on the // base function. + // if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>()) { jvpCallable = jvpRefDecorator->getJVPFunc(); @@ -80,20 +80,9 @@ struct DerivativeCallProcessContext SLANG_ASSERT(jvpCallable); - if (specialization) - { - // Replace the specialization target with the JVP func. - specialization->setOperand(0, jvpCallable); - - // Then replace the JVPDifferentiate inst with the specialization. - derivOfInst->replaceUsesWith(specialization); - } - else - { - // Substitute all uses of the 'derivativeOf' operation - // with the resolved derivative function. - derivOfInst->replaceUsesWith(jvpCallable); - } + // Substitute all uses of the 'derivativeOf' operation + // with the resolved derivative function. + derivOfInst->replaceUsesWith(jvpCallable); // Remove the 'derivativeOf' inst. derivOfInst->removeAndDeallocate(); |
