summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-09-20 20:54:10 -0400
committerGitHub <noreply@github.com>2023-09-20 17:54:10 -0700
commit29c318bfe5c66350a67467e3b6ef08120f00fb7e (patch)
tree77e57e12a9f2e99797b4612e2ff1f64fb483c9c8 /source
parent5b23870eb0d3c0f1545304f67d15cffc16830107 (diff)
Move force inlining step to before `processAutodiffCalls` (and run in loop) (#3217)
* Move auto-diff force inlining step to before `processAutodiffCalls` * Fix `replaceUsesWith` to handle existing inst defined after current use. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp6
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-inline.cpp6
-rw-r--r--source/slang/slang-ir-inline.h3
-rw-r--r--source/slang/slang-ir.cpp55
7 files changed, 72 insertions, 4 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 98c9c9803..3a0ee36be 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -413,6 +413,8 @@ Result linkAndOptimizeIR(
// since we may be missing out on cases prevented by the functions that we just specialized.
performMandatoryEarlyInlining(irModule);
+ performPreAutoDiffForceInlining(irModule);
+
// Unroll loops.
if (codeGenContext->getSink()->getErrorCount() == 0)
{
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index b82818b99..c906f93eb 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1690,8 +1690,6 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
{
insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
removeLinkageDecorations(func);
-
- performPreAutoDiffForceInlining(func);
initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func);
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index ebf7a9484..61baa7dd7 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -969,7 +969,10 @@ void applyCheckpointSet(
bool isInverted = checkpointInfo->invertSet.contains(param);
bool loopInductionInfo = checkpointInfo->loopInductionInfo.tryGetValue(param);
if (!isRecomputed && !isInverted)
+ {
+ ii++;
continue;
+ }
if (!loopInductionInfo)
{
@@ -982,7 +985,10 @@ void applyCheckpointSet(
applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, blockIndexInfo, param);
if (loopInductionInfo)
+ {
+ ii++;
continue;
+ }
// Copy primal branch-arg for predecessor blocks.
HashSet<IRBlock*> predecessorSet;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 335b6572e..7d5659425 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -533,8 +533,6 @@ namespace Slang
{
removeLinkageDecorations(func);
- performPreAutoDiffForceInlining(func);
-
DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
diffTypeContext.setFunc(func);
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 06b63db52..1d308c507 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -872,6 +872,12 @@ bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func)
return pass.considerAllCallSitesRec(func);
}
+bool performPreAutoDiffForceInlining(IRModule* module) // Module-level overload: applies the pre-autodiff force-inlining pass starting from the module's root inst.
+{
+ PreAutoDiffForceInliningPass pass(module);
+ return pass.considerAllCallSitesRec(module->getModuleInst()); // Forwards the pass's bool result unchanged (presumably "something was inlined" - confirm against considerAllCallSitesRec).
+}
+
// Defined in slang-ir-specialize-resource.cpp
bool isResourceType(IRType* type);
bool isIllegalGLSLParameterType(IRType* type);
diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h
index fe050b7b9..539bb26c0 100644
--- a/source/slang/slang-ir-inline.h
+++ b/source/slang/slang-ir-inline.h
@@ -25,6 +25,9 @@ namespace Slang
/// Perform force inlining of functions that do not have custom derivatives.
bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func);
+ /// Perform force inlining of all functions in a module that do not have custom derivatives.
+ bool performPreAutoDiffForceInlining(IRModule* module);
+
/// Inline calls to functions that returns a resource/sampler via either return value or output parameter.
void performGLSLResourceReturnFunctionInlining(IRModule* module);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index a54bc1f2e..b4a8e6f42 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -7177,6 +7177,57 @@ namespace Slang
void validateIRInstOperands(IRInst*);
+
+ // Returns true if `instToCheck` is defined after `otherInst`, i.e. it is reached by following `getNextInst()` forward from `otherInst`. Returns false if they are the same inst or `instToCheck` is never reached.
+ static bool _isInstDefinedAfter(IRInst* instToCheck, IRInst* otherInst)
+ {
+ for (auto inst = otherInst->getNextInst(); inst; inst = inst->getNextInst()) // Start at the successor so `otherInst` itself never matches; linear in the number of following insts.
+ {
+ if (inst == instToCheck)
+ return true;
+ }
+ return false; // Ran off the end of the chain without meeting `instToCheck`.
+ }
+ // Worklist pass: if the operand referenced by `use` is hoistable and is defined after its user under the same non-module parent, move it to just before the user, then apply the same check to the moved inst's own operands and type.
+ static void _maybeHoistOperand(IRUse* use)
+ {
+ ShortList<IRUse*, 16> workList1, workList2; // workList1: uses processed this round; workList2: uses queued for the next round.
+ workList1.add(use);
+ while (workList1.getCount())
+ {
+ for (auto item : workList1)
+ {
+ auto user = item->getUser();
+ auto operand = item->get();
+ if (!operand) // Empty use: nothing to reorder.
+ continue;
+
+ if (!getIROpInfo(operand->getOp()).isHoistable()) // Only hoistable ops are ever moved by this helper.
+ continue;
+
+ // We can't handle the case where operand and user are in different blocks.
+ if (operand->getParent() != user->getParent())
+ continue;
+
+ // We allow out-of-order uses in global scope.
+ if (operand->getParent() && operand->getParent()->getOp() == kIROp_Module)
+ continue;
+
+ // If the operand is defined after user, move it to before user.
+ if (_isInstDefinedAfter(operand, user))
+ {
+ operand->insertBefore(user);
+ for (UInt i = 0; i < operand->getOperandCount(); i++)
+ {
+ workList2.add(operand->getOperands() + i); // The moved inst may now precede its own operands, so queue each operand use for the next round.
+ }
+ workList2.add(&operand->typeUse); // The type use is queued for the same ordering check.
+ }
+ }
+ workList1 = _Move(workList2); // NOTE(review): termination relies on workList2 being left empty by `_Move` - confirm ShortList's move-assignment semantics.
+ }
+ }
+
static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other)
{
IRDeduplicationContext* dedupContext = nullptr;
@@ -7259,6 +7310,10 @@ namespace Slang
// Swap this use over to use the other value.
uu->usedValue = other;
+ // If `other` is hoistable, then we need to make sure `other` is hoisted
+ // to a point before `user`, if it is not already so.
+ _maybeHoistOperand(uu);
+
if (userIsHoistable)
{
// Does the updated inst already exist in the global numbering map?