summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit.cpp')
-rw-r--r--source/slang/slang-emit.cpp25
1 files changed, 25 insertions, 0 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 4666e80d8..1ea54475e 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -10,6 +10,8 @@
#include "slang-ir-collect-global-uniforms.h"
#include "slang-ir-cleanup-void.h"
#include "slang-ir-dce.h"
+#include "slang-ir-diff-call.h"
+#include "slang-ir-diff-jvp.h"
#include "slang-ir-dll-export.h"
#include "slang-ir-dll-import.h"
#include "slang-ir-eliminate-phis.h"
@@ -365,6 +367,29 @@ Result linkAndOptimizeIR(
lowerReinterpret(targetRequest, irModule, sink);
validateIRModuleIfEnabled(codeGenContext, irModule);
+
+ // Inline calls to any functions marked with [__unsafeInlineEarly] again,
+ // since we may be missing out cases prevented by the functions that we just specialzied.
+ performMandatoryEarlyInlining(irModule);
+
+ dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
+
+ // Process higher-order calles to auto-diff passes.
+ // 1. Generate JVP code wherever necessary. (Linearization or "forward-mode" pass)
+ processJVPDerivativeMarkers(irModule, sink);
+
+ // 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass)
+ // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
+
+ // 3. Fill in higher-order invocations with the generated functions.
+ processDerivativeCalls(irModule);
+
+ dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
+
+ validateIRModuleIfEnabled(codeGenContext, irModule);
+
+ applySparseConditionalConstantPropagation(irModule);
+ eliminateDeadCode(irModule);
// For targets that supports dynamic dispatch, we need to lower the
// generics / interface types to ordinary functions and types using