summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-call.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-call.cpp')
-rw-r--r--source/slang/slang-ir-diff-call.cpp90
1 files changed, 90 insertions, 0 deletions
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
new file mode 100644
index 000000000..76ffe3c8b
--- /dev/null
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -0,0 +1,90 @@
+// slang-ir-diff-call.cpp
+#include "slang-ir-diff-call.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+struct DerivativeCallProcessContext
+{
+ // This type passes over the module and replaces
+ // derivative calls with the processed derivative
+ // function.
+ //
+ IRModule* module;
+
+ bool processModule()
+ {
+ // Run through all the global-level instructions,
+ // looking for callable blocks.
+ for (auto inst : module->getGlobalInsts())
+ {
+ // If the instr is a callable, get all the basic blocks
+ if (auto callable = as<IRGlobalValueWithCode>(inst))
+ {
+ // Iterate over each block in the callable
+ for (auto block : callable->getBlocks())
+ {
+ // Iterate over each child instruction.
+ auto child = block->getFirstInst();
+ if (!child) continue;
+
+ do
+ {
+ auto nextChild = child->getNextInst();
+ // Look for IRJVPDerivativeOf
+ if (auto derivOf = as<IRJVPDerivativeOf>(child))
+ {
+ processDerivativeOf(derivOf);
+ }
+ child = nextChild;
+ }
+ while (child);
+ }
+ }
+ }
+ return true;
+ }
+
+ // Perform forward-mode automatic differentiation on
+ // the intstructions.
+ void processDerivativeOf(IRJVPDerivativeOf* derivOfInst)
+ {
+ IRFunc* jvpFunc = nullptr;
+
+ // Resolve the derivative function.
+ //
+ // Check for the 'JVPDerivativeReference' decorator on the
+ // base function.
+ if (auto jvpRefDecorator = derivOfInst->base.get()->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ jvpFunc = jvpRefDecorator->getJVPFunc();
+ }
+
+ // Substitute all uses of the 'derivativeOf' operation
+ // with the resolved derivative function.
+ while (auto use = derivOfInst->firstUse)
+ {
+ use->set(jvpFunc);
+ }
+
+ // Remove the 'derivativeOf'
+ derivOfInst->removeAndDeallocate();
+ }
+};
+
+// Set up context and call main process method.
+//
+bool processDerivativeCalls(
+ IRModule* module,
+ IRDerivativeCallProcessOptions const&)
+{
+ DerivativeCallProcessContext context;
+ context.module = module;
+
+ return context.processModule();
+}
+
+}