summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-01-17 20:21:01 -0500
committerGitHub <noreply@github.com>2023-01-17 17:21:01 -0800
commit1a486813ef0bc7f7a2eb6eaeec2921fd71a2bd05 (patch)
tree5d64c73b7859a1657f71af95da7bc9e78fc58bf2 /source
parent2c437498d3a09b58de17a8865242814d9ea92fde (diff)
Added switch-case support; fixed non-diff parameter transposition (#2596)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp44
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp20
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h64
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h36
5 files changed, 154 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index fa3eb463e..e8fe2beac 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -784,6 +784,47 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig
return InstPair(diffLoop, diffLoop);
}
+InstPair ForwardDiffTranscriber::transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch)
+{
+ // Transcribe condition (primal only, conditions do not produce differentials)
+ auto primalCondition = findOrTranscribePrimalInst(builder, origSwitch->getCondition());
+ SLANG_ASSERT(primalCondition);
+
+ // Transcribe 'default' block
+ IRBlock* diffDefaultBlock = as<IRBlock>(
+ findOrTranscribeDiffInst(builder, origSwitch->getDefaultLabel()));
+ SLANG_ASSERT(diffDefaultBlock);
+
+ // Transcribe 'default' block
+ IRBlock* diffBreakBlock = as<IRBlock>(
+ findOrTranscribeDiffInst(builder, origSwitch->getBreakLabel()));
+ SLANG_ASSERT(diffBreakBlock);
+
+ // Transcribe all other operands
+ List<IRInst*> diffCaseValuesAndLabels;
+ for (UIndex ii = 0; ii < origSwitch->getCaseCount(); ii ++)
+ {
+ auto primalCaseValue = findOrTranscribePrimalInst(builder, origSwitch->getCaseValue(ii));
+ SLANG_ASSERT(primalCaseValue);
+
+ auto diffCaseBlock = findOrTranscribeDiffInst(builder, origSwitch->getCaseLabel(ii));
+ SLANG_ASSERT(diffCaseBlock);
+
+ diffCaseValuesAndLabels.add(primalCaseValue);
+ diffCaseValuesAndLabels.add(diffCaseBlock);
+ }
+
+ auto diffSwitchInst = builder->emitSwitch(
+ primalCondition,
+ diffBreakBlock,
+ diffDefaultBlock,
+ diffCaseValuesAndLabels.getCount(),
+ diffCaseValuesAndLabels.getBuffer());
+ builder->markInstAsMixedDifferential(diffSwitchInst);
+
+ return InstPair(diffSwitchInst, diffSwitchInst);
+}
+
InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
{
// IfElse Statements come with 4 blocks. We transcribe each block into it's
@@ -1123,6 +1164,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_ifElse:
return transcribeIfElse(builder, as<IRIfElse>(origInst));
+
+ case kIROp_Switch:
+ return transcribeSwitch(builder, as<IRSwitch>(origInst));
case kIROp_MakeDifferentialPair:
return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst));
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 828916c01..b09b57974 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -65,6 +65,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse);
+ InstPair transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch);
+
InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst);
InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 8d6419cf2..de4fbe182 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -666,14 +666,15 @@ namespace Slang
builder->setInsertInto(fwdDiffParameterBlock);
- // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
- for (auto child = fwdDiffParameterBlock->getFirstParam(); child;)
+ List<IRParam*> fwdParams;
+ for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam())
{
- IRParam* nextChild = child->getNextParam();
+ fwdParams.add(child);
+ }
- auto fwdParam = as<IRParam>(child);
- SLANG_ASSERT(fwdParam);
-
+ // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
+ for (auto fwdParam : fwdParams)
+ {
// TODO: Handle ptr<pair> types.
if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
{
@@ -690,10 +691,11 @@ namespace Slang
else
{
// Default case (parameter has nothing to do with differentiation)
- // Do nothing.
+ // Simply move the parameter to the end.
+ //
+ fwdParam->removeFromParent();
+ fwdDiffParameterBlock->addParam(fwdParam);
}
-
- child = nextChild;
}
auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index cbdb0a998..ae1a5dd70 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -858,7 +858,8 @@ struct DiffTransposePass
// Check for predecessors count.
for (auto predecessor : fwdBlockInst->getPredecessors())
{
- fwdPredecesorBlocks.add(predecessor);
+ if (!fwdPredecesorBlocks.contains(predecessor))
+ fwdPredecesorBlocks.add(predecessor);
}
SLANG_ASSERT(fwdPredecesorBlocks.getCount() > 0);
@@ -915,7 +916,7 @@ struct DiffTransposePass
{
IRBlock* revBlock = revBlockMap[fwdBlockInst];
- // If we already have a terminator, we've resolved it during
+ // If we already have a terminator, we've probably resolved it during
// tryEmitTerminator()
//
if (revBlock->getTerminator() != nullptr)
@@ -960,7 +961,63 @@ struct DiffTransposePass
}
builder->emitIfElse(condition, revTrueBlock, revFalseBlock, revAfterBlock);
- break;
+ return true;
+ }
+ case kIROp_Switch:
+ {
+ auto switchInst = as<IRSwitch>(terminatorInst);
+
+ auto condition = switchInst->getCondition();
+ SLANG_ASSERT(!isDifferentialInst(condition));
+
+ // fwd origin block is the reverse 'break' block.
+ auto revAfterBlock = as<IRBlock>(
+ revBlockMap[as<IRBlock>(switchInst->getParent())]);
+
+ // Find regions for every branch, and find the reverse-mode
+ // version of the each exit block.
+ Region* defaultRegion = regionMap[switchInst->getDefaultLabel()];
+ IRBlock* revDefaultBlock = revBlockMap[defaultRegion->exitBlock];
+
+ List<IRBlock*> revCaseBlocks;
+ for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++)
+ {
+ Region* caseRegion = regionMap[switchInst->getCaseLabel(ii)];
+ IRBlock* revCaseBlock = revBlockMap[caseRegion->exitBlock];
+ revCaseBlocks.add(revCaseBlock);
+ }
+
+ // If we have phi derivatives to pass on,
+ // we need to add dummy blocks to pass them using
+ // an unconditional branch.
+ //
+ if (phiParamGrads.getCount() > 0)
+ {
+ revDefaultBlock = insertPhiBlockBefore(revDefaultBlock, phiParamGrads);
+ revDefaultBlock->insertAfter(revBlock);
+
+ for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++)
+ {
+ revCaseBlocks[ii] = insertPhiBlockBefore(revCaseBlocks[ii], phiParamGrads);
+ revCaseBlocks[ii]->insertAfter(revBlock);
+ }
+ }
+
+ List<IRInst*> revCaseArgs;
+ for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++)
+ {
+ revCaseArgs.add(switchInst->getCaseValue(ii));
+ revCaseArgs.add(revCaseBlocks[ii]);
+ }
+
+ builder->emitSwitch(
+ condition,
+ revAfterBlock,
+ revDefaultBlock,
+ revCaseArgs.getCount(),
+ revCaseArgs.getBuffer());
+
+ return true;
}
default:
SLANG_UNIMPLEMENTED_X("Unhandled control flow inst during transposition");
@@ -1010,6 +1067,7 @@ struct DiffTransposePass
case kIROp_conditionalBranch:
case kIROp_ifElse:
case kIROp_loop:
+ case kIROp_Switch:
{
// Ignore. transposeBlock() should take care of adding the
// appropriate branch instruction.
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index b06ed29bf..e616578c1 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -425,6 +425,40 @@ struct DiffUnzipPass
as<IRBlock>(diffMap[afterBlock])));
}
+ case kIROp_Switch:
+ {
+ auto switchInst = as<IRSwitch>(branchInst);
+ auto breakBlock = switchInst->getBreakLabel();
+ auto defaultBlock = switchInst->getDefaultLabel();
+ auto condInst = switchInst->getCondition();
+
+ List<IRInst*> primalCaseArgs;
+ List<IRInst*> diffCaseArgs;
+
+ for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++)
+ {
+ primalCaseArgs.add(switchInst->getCaseValue(ii));
+ diffCaseArgs.add(switchInst->getCaseValue(ii));
+
+ primalCaseArgs.add(primalMap[switchInst->getCaseLabel(ii)]);
+ diffCaseArgs.add(diffMap[switchInst->getCaseLabel(ii)]);
+ }
+
+ return InstPair(
+ primalBuilder->emitSwitch(
+ condInst,
+ as<IRBlock>(primalMap[breakBlock]),
+ as<IRBlock>(primalMap[defaultBlock]),
+ primalCaseArgs.getCount(),
+ primalCaseArgs.getBuffer()),
+ diffBuilder->emitSwitch(
+ condInst,
+ as<IRBlock>(diffMap[breakBlock]),
+ as<IRBlock>(diffMap[defaultBlock]),
+ diffCaseArgs.getCount(),
+ diffCaseArgs.getBuffer()));
+ }
+
default:
SLANG_UNEXPECTED("Unhandled instruction");
}
@@ -452,6 +486,8 @@ struct DiffUnzipPass
case kIROp_unconditionalBranch:
case kIROp_conditionalBranch:
case kIROp_ifElse:
+ case kIROp_Switch:
+ case kIROp_loop:
return splitControlFlow(primalBuilder, diffBuilder, inst);
case kIROp_Unreachable: