summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp107
-rw-r--r--source/slang/slang-ir-eliminate-multilevel-break.cpp3
-rw-r--r--source/slang/slang-ir-util.cpp107
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--tests/autodiff/loop-mutating-array.slang60
5 files changed, 172 insertions, 107 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index 08a6c84e7..30e832719 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -707,113 +707,6 @@ struct CFGNormalizationPass
}
};
-static void legalizeDefUse(IRGlobalValueWithCode* func)
-{
- auto dom = computeDominatorTree(func);
- for (auto block : func->getBlocks())
- {
- for (auto inst : block->getModifiableChildren())
- {
- // Inspect all uses of `inst` and find the common dominator of all use sites.
- IRBlock* commonDominator = block;
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- auto userBlock = as<IRBlock>(use->getUser()->getParent());
- if (!userBlock)
- continue;
- while (commonDominator && !dom->dominates(commonDominator, userBlock))
- {
- commonDominator = dom->getImmediateDominator(commonDominator);
- }
- }
- SLANG_ASSERT(commonDominator);
-
- if (commonDominator == block)
- continue;
-
- // If the common dominator is not `block`, it means we have detected
- // uses that is no longer dominated by the current definition, and need
- // to be fixed.
-
- // Normally, we can simply move the definition to the common dominator.
- // An exception is when the common dominator is the target block of a
- // loop. Note that after normalization, loops are in the form of:
- // ```
- // loop { if (condition) block; else break; }
- // ```
- // If we find ourselves needing to make the inst available right before
- // the `if`, it means we are seeing uses of the inst outside the loop.
- // In this case, we should insert a var/move the inst before the loop
- // instead of before the `if`. This situation can occur in the IR if
- // the original code is lowered from a `do-while` loop.
- for (auto use = commonDominator->firstUse; use; use = use->nextUse)
- {
- if (auto loopUser = as<IRLoop>(use->getUser()))
- {
- if (loopUser->getTargetBlock() == commonDominator)
- {
- bool shouldMoveToHeader = false;
- // Check that the break-block dominates any of the uses are past the break
- // block
- for (auto _use = inst->firstUse; _use; _use = _use->nextUse)
- {
- if (dom->dominates(
- loopUser->getBreakBlock(),
- _use->getUser()->getParent()))
- {
- shouldMoveToHeader = true;
- break;
- }
- }
-
- if (shouldMoveToHeader)
- commonDominator = as<IRBlock>(loopUser->getParent());
- break;
- }
- }
- }
- // Now we can legalize uses based on the type of `inst`.
- if (auto var = as<IRVar>(inst))
- {
- // If inst is an var, this is easy, we just move it to the
- // common dominator.
- var->insertBefore(commonDominator->getTerminator());
- }
- else
- {
- // For all other insts, we need to create a local var for it,
- // and replace all uses with a load from the local var.
- IRBuilder builder(func);
- builder.setInsertBefore(commonDominator->getTerminator());
- IRVar* tempVar = builder.emitVar(inst->getFullType());
- auto defaultVal = builder.emitDefaultConstruct(inst->getFullType());
- builder.emitStore(tempVar, defaultVal);
-
- builder.setInsertAfter(inst);
- builder.emitStore(tempVar, inst);
-
- traverseUses(
- inst,
- [&](IRUse* use)
- {
- auto userBlock = as<IRBlock>(use->getUser()->getParent());
- if (!userBlock)
- return;
- // Only fix the use of the current definition of `inst` does not
- // dominate it.
- if (!dom->dominates(block, userBlock))
- {
- // Replace the use with a load of tempVar.
- builder.setInsertBefore(use->getUser());
- auto load = builder.emitLoad(tempVar);
- builder.replaceOperand(use, load);
- }
- });
- }
- }
- }
-}
-
void normalizeCFG(
IRModule* module,
IRGlobalValueWithCode* func,
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
index a54d1b46d..741b7fbd0 100644
--- a/source/slang/slang-ir-eliminate-multilevel-break.cpp
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -5,6 +5,7 @@
#include "slang-ir-dominators.h"
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
#include "slang-ir.h"
namespace Slang
@@ -475,6 +476,8 @@ struct EliminateMultiLevelBreakContext
}
}
}
+
+ legalizeDefUse(func);
}
};
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index d05e1db7d..3043e7e31 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1986,4 +1986,111 @@ Int getSpecializationConstantId(IRGlobalParam* param)
return offset->getOffset();
}
+void legalizeDefUse(IRGlobalValueWithCode* func)
+{
+ auto dom = computeDominatorTree(func);
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getModifiableChildren())
+ {
+ // Inspect all uses of `inst` and find the common dominator of all use sites.
+ IRBlock* commonDominator = block;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto userBlock = as<IRBlock>(use->getUser()->getParent());
+ if (!userBlock)
+ continue;
+ while (commonDominator && !dom->dominates(commonDominator, userBlock))
+ {
+ commonDominator = dom->getImmediateDominator(commonDominator);
+ }
+ }
+ SLANG_ASSERT(commonDominator);
+
+ if (commonDominator == block)
+ continue;
+
+ // If the common dominator is not `block`, it means we have detected
+ // uses that is no longer dominated by the current definition, and need
+ // to be fixed.
+
+ // Normally, we can simply move the definition to the common dominator.
+ // An exception is when the common dominator is the target block of a
+ // loop. Note that after normalization, loops are in the form of:
+ // ```
+ // loop { if (condition) block; else break; }
+ // ```
+ // If we find ourselves needing to make the inst available right before
+ // the `if`, it means we are seeing uses of the inst outside the loop.
+ // In this case, we should insert a var/move the inst before the loop
+ // instead of before the `if`. This situation can occur in the IR if
+ // the original code is lowered from a `do-while` loop.
+ for (auto use = commonDominator->firstUse; use; use = use->nextUse)
+ {
+ if (auto loopUser = as<IRLoop>(use->getUser()))
+ {
+ if (loopUser->getTargetBlock() == commonDominator)
+ {
+ bool shouldMoveToHeader = false;
+ // Check that the break-block dominates any of the uses are past the break
+ // block
+ for (auto _use = inst->firstUse; _use; _use = _use->nextUse)
+ {
+ if (dom->dominates(
+ loopUser->getBreakBlock(),
+ _use->getUser()->getParent()))
+ {
+ shouldMoveToHeader = true;
+ break;
+ }
+ }
+
+ if (shouldMoveToHeader)
+ commonDominator = as<IRBlock>(loopUser->getParent());
+ break;
+ }
+ }
+ }
+ // Now we can legalize uses based on the type of `inst`.
+ if (auto var = as<IRVar>(inst))
+ {
+ // If inst is an var, this is easy, we just move it to the
+ // common dominator.
+ var->insertBefore(commonDominator->getTerminator());
+ }
+ else
+ {
+ // For all other insts, we need to create a local var for it,
+ // and replace all uses with a load from the local var.
+ IRBuilder builder(func);
+ builder.setInsertBefore(commonDominator->getTerminator());
+ IRVar* tempVar = builder.emitVar(inst->getFullType());
+ auto defaultVal = builder.emitDefaultConstruct(inst->getFullType());
+ builder.emitStore(tempVar, defaultVal);
+
+ builder.setInsertAfter(inst);
+ builder.emitStore(tempVar, inst);
+
+ traverseUses(
+ inst,
+ [&](IRUse* use)
+ {
+ auto userBlock = as<IRBlock>(use->getUser()->getParent());
+ if (!userBlock)
+ return;
+ // Only fix the use of the current definition of `inst` does not
+ // dominate it.
+ if (!dom->dominates(block, userBlock))
+ {
+ // Replace the use with a load of tempVar.
+ builder.setInsertBefore(use->getUser());
+ auto load = builder.emitLoad(tempVar);
+ builder.replaceOperand(use, load);
+ }
+ });
+ }
+ }
+ }
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 666ac71c0..46e0105f5 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -375,6 +375,8 @@ IRType* getIRVectorBaseType(IRType* type);
Int getSpecializationConstantId(IRGlobalParam* param);
+void legalizeDefUse(IRGlobalValueWithCode* func);
+
} // namespace Slang
#endif
diff --git a/tests/autodiff/loop-mutating-array.slang b/tests/autodiff/loop-mutating-array.slang
new file mode 100644
index 000000000..0eada5153
--- /dev/null
+++ b/tests/autodiff/loop-mutating-array.slang
@@ -0,0 +1,60 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+
+struct SpatialVertex : IDifferentiable
+{
+ float x;
+};
+
+struct MaterialVertex
+{
+ float x;
+};
+
+//TEST_INPUT:ubuffer(data=[2.0 2.0 2.0 2.0 2.0], stride=4):name=pathVertices
+RWStructuredBuffer<MaterialVertex> pathVertices;
+
+[Differentiable]
+SpatialVertex transform(float p, MaterialVertex m)
+{
+ return { p * m.x };
+}
+
+[Differentiable]
+float test_simple_loop(float y)
+{
+ SpatialVertex vShade[2];
+ int pathLength = 1;
+
+ [ForceUnroll]
+ for (int i = 0; i < 2; i++)
+ {
+ if (!(pathVertices[i].x > 1.4))
+ {
+ pathLength = i;
+ break;
+ }
+
+ vShade[i] = transform(y, pathVertices[i]);
+ }
+
+ return vShade[0].x + vShade[1].x;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpy = dpfloat(1.0, 1.0);
+
+ var dpresult = fwd_diff(test_simple_loop)(dpy);
+ outputBuffer[0] = pathVertices[0].x; // CHECK: 2.0
+ outputBuffer[1] = dpresult.d; // CHECK: 4.0
+ }
+}