diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-diff-call.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-diff-call.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-call.cpp | 35 |
1 files changed, 17 insertions, 18 deletions
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index a574d6b7e..1aeaa3182 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -1,8 +1,8 @@ // slang-ir-diff-call.cpp #include "slang-ir-diff-call.h" -#include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir.h" namespace Slang { @@ -10,14 +10,14 @@ namespace Slang struct DerivativeCallProcessContext { // This type passes over the module and replaces - // derivative calls with the processed derivative + // derivative calls with the processed derivative // function. // - IRModule* module; + IRModule* module; bool processModule() { - // Run through all the global-level instructions, + // Run through all the global-level instructions, // looking for callable blocks. for (auto inst : module->getGlobalInsts()) { @@ -29,9 +29,10 @@ struct DerivativeCallProcessContext { // Iterate over each child instruction. auto child = block->getFirstInst(); - if (!child) continue; + if (!child) + continue; - do + do { auto nextChild = child->getNextInst(); // Look for IRForwardDifferentiate @@ -40,29 +41,29 @@ struct DerivativeCallProcessContext processDifferentiate(derivOf); } child = nextChild; - } - while (child); + } while (child); } } } return true; } - // Perform forward-mode automatic differentiation on + // Perform forward-mode automatic differentiation on // the intstructions. void processDifferentiate(IRForwardDifferentiate* derivOfInst) { IRInst* jvpCallable = nullptr; - // First get base function + // First get base function auto origCallable = derivOfInst->getBaseFn(); // Resolve the derivative function for IRForwardDifferentiate(IRSpecialize(IRFunc)) // Check the specialize inst for a reference to the derivative fn. - // + // if (auto origSpecialize = as<IRSpecialize>(origCallable)) { - if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) + if (auto jvpSpecRefDecorator = + origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) { jvpCallable = jvpSpecRefDecorator->getForwardDerivativeFunc(); } @@ -80,7 +81,7 @@ struct DerivativeCallProcessContext SLANG_ASSERT(jvpCallable); - // Substitute all uses of the 'derivativeOf' operation + // Substitute all uses of the 'derivativeOf' operation // with the resolved derivative function. derivOfInst->replaceUsesWith(jvpCallable); @@ -90,10 +91,8 @@ struct DerivativeCallProcessContext }; // Set up context and call main process method. -// -bool processDerivativeCalls( - IRModule* module, - IRDerivativeCallProcessOptions const&) +// +bool processDerivativeCalls(IRModule* module, IRDerivativeCallProcessOptions const&) { DerivativeCallProcessContext context; context.module = module; @@ -101,4 +100,4 @@ bool processDerivativeCalls( return context.processModule(); } -} +} // namespace Slang |
