From 1a486813ef0bc7f7a2eb6eaeec2921fd71a2bd05 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 17 Jan 2023 20:21:01 -0500 Subject: Added switch-case support; fixed non-diff parameter transposition (#2596) --- source/slang/slang-ir-autodiff-fwd.cpp | 44 ++++++++++++++++++++ source/slang/slang-ir-autodiff-fwd.h | 2 + source/slang/slang-ir-autodiff-rev.cpp | 20 +++++----- source/slang/slang-ir-autodiff-transpose.h | 64 ++++++++++++++++++++++++++++-- source/slang/slang-ir-autodiff-unzip.h | 36 +++++++++++++++++ 5 files changed, 154 insertions(+), 12 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fa3eb463e..e8fe2beac 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -784,6 +784,47 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig return InstPair(diffLoop, diffLoop); } +InstPair ForwardDiffTranscriber::transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch) +{ + // Transcribe condition (primal only, conditions do not produce differentials) + auto primalCondition = findOrTranscribePrimalInst(builder, origSwitch->getCondition()); + SLANG_ASSERT(primalCondition); + + // Transcribe 'default' block + IRBlock* diffDefaultBlock = as( + findOrTranscribeDiffInst(builder, origSwitch->getDefaultLabel())); + SLANG_ASSERT(diffDefaultBlock); + + // Transcribe 'default' block + IRBlock* diffBreakBlock = as( + findOrTranscribeDiffInst(builder, origSwitch->getBreakLabel())); + SLANG_ASSERT(diffBreakBlock); + + // Transcribe all other operands + List diffCaseValuesAndLabels; + for (UIndex ii = 0; ii < origSwitch->getCaseCount(); ii ++) + { + auto primalCaseValue = findOrTranscribePrimalInst(builder, origSwitch->getCaseValue(ii)); + SLANG_ASSERT(primalCaseValue); + + auto diffCaseBlock = findOrTranscribeDiffInst(builder, origSwitch->getCaseLabel(ii)); + SLANG_ASSERT(diffCaseBlock); + + diffCaseValuesAndLabels.add(primalCaseValue); + diffCaseValuesAndLabels.add(diffCaseBlock); + } + + auto diffSwitchInst = builder->emitSwitch( + primalCondition, + diffBreakBlock, + diffDefaultBlock, + diffCaseValuesAndLabels.getCount(), + diffCaseValuesAndLabels.getBuffer()); + builder->markInstAsMixedDifferential(diffSwitchInst); + + return InstPair(diffSwitchInst, diffSwitchInst); +} + InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) { // IfElse Statements come with 4 blocks. We transcribe each block into it's @@ -1123,6 +1164,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_ifElse: return transcribeIfElse(builder, as(origInst)); + + case kIROp_Switch: + return transcribeSwitch(builder, as(origInst)); case kIROp_MakeDifferentialPair: return transcribeMakeDifferentialPair(builder, as(origInst)); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 828916c01..b09b57974 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -65,6 +65,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse); + InstPair transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch); + InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst); InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 8d6419cf2..de4fbe182 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -666,14 +666,15 @@ namespace Slang builder->setInsertInto(fwdDiffParameterBlock); - // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> - for (auto child = fwdDiffParameterBlock->getFirstParam(); child;) + List fwdParams; + for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam()) { - IRParam* nextChild = child->getNextParam(); + fwdParams.add(child); + } - auto fwdParam = as(child); - SLANG_ASSERT(fwdParam); - + // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> + for (auto fwdParam : fwdParams) + { // TODO: Handle ptr types. if (auto diffPairType = as(fwdParam->getDataType())) { @@ -690,10 +691,11 @@ namespace Slang else { // Default case (parameter has nothing to do with differentiation) - // Do nothing. + // Simply move the parameter to the end. + // + fwdParam->removeFromParent(); + fwdDiffParameterBlock->addParam(fwdParam); } - - child = nextChild; } auto paramCount = as(diffFunc->getDataType())->getParamCount(); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index cbdb0a998..ae1a5dd70 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -858,7 +858,8 @@ struct DiffTransposePass // Check for predecessors count. for (auto predecessor : fwdBlockInst->getPredecessors()) { - fwdPredecesorBlocks.add(predecessor); + if (!fwdPredecesorBlocks.contains(predecessor)) + fwdPredecesorBlocks.add(predecessor); } SLANG_ASSERT(fwdPredecesorBlocks.getCount() > 0); @@ -915,7 +916,7 @@ struct DiffTransposePass { IRBlock* revBlock = revBlockMap[fwdBlockInst]; - // If we already have a terminator, we've resolved it during + // If we already have a terminator, we've probably resolved it during // tryEmitTerminator() // if (revBlock->getTerminator() != nullptr) @@ -960,7 +961,63 @@ struct DiffTransposePass } builder->emitIfElse(condition, revTrueBlock, revFalseBlock, revAfterBlock); - break; + return true; + } + case kIROp_Switch: + { + auto switchInst = as(terminatorInst); + + auto condition = switchInst->getCondition(); + SLANG_ASSERT(!isDifferentialInst(condition)); + + // fwd origin block is the reverse 'break' block. + auto revAfterBlock = as( + revBlockMap[as(switchInst->getParent())]); + + // Find regions for every branch, and find the reverse-mode + // version of the each exit block. + Region* defaultRegion = regionMap[switchInst->getDefaultLabel()]; + IRBlock* revDefaultBlock = revBlockMap[defaultRegion->exitBlock]; + + List revCaseBlocks; + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + Region* caseRegion = regionMap[switchInst->getCaseLabel(ii)]; + IRBlock* revCaseBlock = revBlockMap[caseRegion->exitBlock]; + revCaseBlocks.add(revCaseBlock); + } + + // If we have phi derivatives to pass on, + // we need to add dummy blocks to pass them using + // an unconditional branch. + // + if (phiParamGrads.getCount() > 0) + { + revDefaultBlock = insertPhiBlockBefore(revDefaultBlock, phiParamGrads); + revDefaultBlock->insertAfter(revBlock); + + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + revCaseBlocks[ii] = insertPhiBlockBefore(revCaseBlocks[ii], phiParamGrads); + revCaseBlocks[ii]->insertAfter(revBlock); + } + } + + List revCaseArgs; + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + revCaseArgs.add(switchInst->getCaseValue(ii)); + revCaseArgs.add(revCaseBlocks[ii]); + } + + builder->emitSwitch( + condition, + revAfterBlock, + revDefaultBlock, + revCaseArgs.getCount(), + revCaseArgs.getBuffer()); + + return true; } default: SLANG_UNIMPLEMENTED_X("Unhandled control flow inst during transposition"); @@ -1010,6 +1067,7 @@ struct DiffTransposePass case kIROp_conditionalBranch: case kIROp_ifElse: case kIROp_loop: + case kIROp_Switch: { // Ignore. transposeBlock() should take care of adding the // appropriate branch instruction. diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index b06ed29bf..e616578c1 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -425,6 +425,40 @@ struct DiffUnzipPass as(diffMap[afterBlock]))); } + case kIROp_Switch: + { + auto switchInst = as(branchInst); + auto breakBlock = switchInst->getBreakLabel(); + auto defaultBlock = switchInst->getDefaultLabel(); + auto condInst = switchInst->getCondition(); + + List primalCaseArgs; + List diffCaseArgs; + + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + primalCaseArgs.add(switchInst->getCaseValue(ii)); + diffCaseArgs.add(switchInst->getCaseValue(ii)); + + primalCaseArgs.add(primalMap[switchInst->getCaseLabel(ii)]); + diffCaseArgs.add(diffMap[switchInst->getCaseLabel(ii)]); + } + + return InstPair( + primalBuilder->emitSwitch( + condInst, + as(primalMap[breakBlock]), + as(primalMap[defaultBlock]), + primalCaseArgs.getCount(), + primalCaseArgs.getBuffer()), + diffBuilder->emitSwitch( + condInst, + as(diffMap[breakBlock]), + as(diffMap[defaultBlock]), + diffCaseArgs.getCount(), + diffCaseArgs.getBuffer())); + } + default: SLANG_UNEXPECTED("Unhandled instruction"); } @@ -452,6 +486,8 @@ struct DiffUnzipPass case kIROp_unconditionalBranch: case kIROp_conditionalBranch: case kIROp_ifElse: + case kIROp_Switch: + case kIROp_loop: return splitControlFlow(primalBuilder, diffBuilder, inst); case kIROp_Unreachable: -- cgit v1.2.3