summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-call.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-10-29 14:49:26 +0800
committerGitHub <noreply@github.com>2024-10-29 14:49:26 +0800
commitf65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
treeea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-diff-call.cpp
parenta729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format
* format * Minor test fixes * enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-diff-call.cpp')
-rw-r--r--source/slang/slang-ir-diff-call.cpp35
1 files changed, 17 insertions, 18 deletions
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index a574d6b7e..1aeaa3182 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -1,8 +1,8 @@
// slang-ir-diff-call.cpp
#include "slang-ir-diff-call.h"
-#include "slang-ir.h"
#include "slang-ir-insts.h"
+#include "slang-ir.h"
namespace Slang
{
@@ -10,14 +10,14 @@ namespace Slang
struct DerivativeCallProcessContext
{
// This type passes over the module and replaces
- // derivative calls with the processed derivative
+ // derivative calls with the processed derivative
// function.
//
- IRModule* module;
+ IRModule* module;
bool processModule()
{
- // Run through all the global-level instructions,
+ // Run through all the global-level instructions,
// looking for callable blocks.
for (auto inst : module->getGlobalInsts())
{
@@ -29,9 +29,10 @@ struct DerivativeCallProcessContext
{
// Iterate over each child instruction.
auto child = block->getFirstInst();
- if (!child) continue;
+ if (!child)
+ continue;
- do
+ do
{
auto nextChild = child->getNextInst();
// Look for IRForwardDifferentiate
@@ -40,29 +41,29 @@ struct DerivativeCallProcessContext
processDifferentiate(derivOf);
}
child = nextChild;
- }
- while (child);
+ } while (child);
}
}
}
return true;
}
- // Perform forward-mode automatic differentiation on
+ // Perform forward-mode automatic differentiation on
// the intstructions.
void processDifferentiate(IRForwardDifferentiate* derivOfInst)
{
IRInst* jvpCallable = nullptr;
- // First get base function
+ // First get base function
auto origCallable = derivOfInst->getBaseFn();
// Resolve the derivative function for IRForwardDifferentiate(IRSpecialize(IRFunc))
// Check the specialize inst for a reference to the derivative fn.
- //
+ //
if (auto origSpecialize = as<IRSpecialize>(origCallable))
{
- if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
+ if (auto jvpSpecRefDecorator =
+ origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
{
jvpCallable = jvpSpecRefDecorator->getForwardDerivativeFunc();
}
@@ -80,7 +81,7 @@ struct DerivativeCallProcessContext
SLANG_ASSERT(jvpCallable);
- // Substitute all uses of the 'derivativeOf' operation
+ // Substitute all uses of the 'derivativeOf' operation
// with the resolved derivative function.
derivOfInst->replaceUsesWith(jvpCallable);
@@ -90,10 +91,8 @@ struct DerivativeCallProcessContext
};
// Set up context and call main process method.
-//
-bool processDerivativeCalls(
- IRModule* module,
- IRDerivativeCallProcessOptions const&)
+//
+bool processDerivativeCalls(IRModule* module, IRDerivativeCallProcessOptions const&)
{
DerivativeCallProcessContext context;
context.module = module;
@@ -101,4 +100,4 @@ bool processDerivativeCalls(
return context.processModule();
}
-}
+} // namespace Slang