diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-01-17 20:21:01 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-17 17:21:01 -0800 |
| commit | 1a486813ef0bc7f7a2eb6eaeec2921fd71a2bd05 (patch) | |
| tree | 5d64c73b7859a1657f71af95da7bc9e78fc58bf2 /source | |
| parent | 2c437498d3a09b58de17a8865242814d9ea92fde (diff) | |
Added switch-case support; fixed non-diff parameter transposition (#2596)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 64 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 36 |
5 files changed, 154 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fa3eb463e..e8fe2beac 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -784,6 +784,47 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig return InstPair(diffLoop, diffLoop); } +InstPair ForwardDiffTranscriber::transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch) +{ + // Transcribe condition (primal only, conditions do not produce differentials) + auto primalCondition = findOrTranscribePrimalInst(builder, origSwitch->getCondition()); + SLANG_ASSERT(primalCondition); + + // Transcribe 'default' block + IRBlock* diffDefaultBlock = as<IRBlock>( + findOrTranscribeDiffInst(builder, origSwitch->getDefaultLabel())); + SLANG_ASSERT(diffDefaultBlock); + + // Transcribe 'default' block + IRBlock* diffBreakBlock = as<IRBlock>( + findOrTranscribeDiffInst(builder, origSwitch->getBreakLabel())); + SLANG_ASSERT(diffBreakBlock); + + // Transcribe all other operands + List<IRInst*> diffCaseValuesAndLabels; + for (UIndex ii = 0; ii < origSwitch->getCaseCount(); ii ++) + { + auto primalCaseValue = findOrTranscribePrimalInst(builder, origSwitch->getCaseValue(ii)); + SLANG_ASSERT(primalCaseValue); + + auto diffCaseBlock = findOrTranscribeDiffInst(builder, origSwitch->getCaseLabel(ii)); + SLANG_ASSERT(diffCaseBlock); + + diffCaseValuesAndLabels.add(primalCaseValue); + diffCaseValuesAndLabels.add(diffCaseBlock); + } + + auto diffSwitchInst = builder->emitSwitch( + primalCondition, + diffBreakBlock, + diffDefaultBlock, + diffCaseValuesAndLabels.getCount(), + diffCaseValuesAndLabels.getBuffer()); + builder->markInstAsMixedDifferential(diffSwitchInst); + + return InstPair(diffSwitchInst, diffSwitchInst); +} + InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) { // IfElse Statements come with 4 blocks. We transcribe each block into it's @@ -1123,6 +1164,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); + + case kIROp_Switch: + return transcribeSwitch(builder, as<IRSwitch>(origInst)); case kIROp_MakeDifferentialPair: return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst)); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 828916c01..b09b57974 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -65,6 +65,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse); + InstPair transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch); + InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst); InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 8d6419cf2..de4fbe182 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -666,14 +666,15 @@ namespace Slang builder->setInsertInto(fwdDiffParameterBlock); - // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> - for (auto child = fwdDiffParameterBlock->getFirstParam(); child;) + List<IRParam*> fwdParams; + for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam()) { - IRParam* nextChild = child->getNextParam(); + fwdParams.add(child); + } - auto fwdParam = as<IRParam>(child); - SLANG_ASSERT(fwdParam); - + // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> + for (auto fwdParam : fwdParams) + { // TODO: Handle ptr<pair> types. if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType())) { @@ -690,10 +691,11 @@ namespace Slang else { // Default case (parameter has nothing to do with differentiation) - // Do nothing. + // Simply move the parameter to the end. + // + fwdParam->removeFromParent(); + fwdDiffParameterBlock->addParam(fwdParam); } - - child = nextChild; } auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index cbdb0a998..ae1a5dd70 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -858,7 +858,8 @@ struct DiffTransposePass // Check for predecessors count. for (auto predecessor : fwdBlockInst->getPredecessors()) { - fwdPredecesorBlocks.add(predecessor); + if (!fwdPredecesorBlocks.contains(predecessor)) + fwdPredecesorBlocks.add(predecessor); } SLANG_ASSERT(fwdPredecesorBlocks.getCount() > 0); @@ -915,7 +916,7 @@ struct DiffTransposePass { IRBlock* revBlock = revBlockMap[fwdBlockInst]; - // If we already have a terminator, we've resolved it during + // If we already have a terminator, we've probably resolved it during // tryEmitTerminator() // if (revBlock->getTerminator() != nullptr) @@ -960,7 +961,63 @@ struct DiffTransposePass } builder->emitIfElse(condition, revTrueBlock, revFalseBlock, revAfterBlock); - break; + return true; + } + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(terminatorInst); + + auto condition = switchInst->getCondition(); + SLANG_ASSERT(!isDifferentialInst(condition)); + + // fwd origin block is the reverse 'break' block. + auto revAfterBlock = as<IRBlock>( + revBlockMap[as<IRBlock>(switchInst->getParent())]); + + // Find regions for every branch, and find the reverse-mode + // version of the each exit block. + Region* defaultRegion = regionMap[switchInst->getDefaultLabel()]; + IRBlock* revDefaultBlock = revBlockMap[defaultRegion->exitBlock]; + + List<IRBlock*> revCaseBlocks; + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + Region* caseRegion = regionMap[switchInst->getCaseLabel(ii)]; + IRBlock* revCaseBlock = revBlockMap[caseRegion->exitBlock]; + revCaseBlocks.add(revCaseBlock); + } + + // If we have phi derivatives to pass on, + // we need to add dummy blocks to pass them using + // an unconditional branch. + // + if (phiParamGrads.getCount() > 0) + { + revDefaultBlock = insertPhiBlockBefore(revDefaultBlock, phiParamGrads); + revDefaultBlock->insertAfter(revBlock); + + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + revCaseBlocks[ii] = insertPhiBlockBefore(revCaseBlocks[ii], phiParamGrads); + revCaseBlocks[ii]->insertAfter(revBlock); + } + } + + List<IRInst*> revCaseArgs; + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + revCaseArgs.add(switchInst->getCaseValue(ii)); + revCaseArgs.add(revCaseBlocks[ii]); + } + + builder->emitSwitch( + condition, + revAfterBlock, + revDefaultBlock, + revCaseArgs.getCount(), + revCaseArgs.getBuffer()); + + return true; } default: SLANG_UNIMPLEMENTED_X("Unhandled control flow inst during transposition"); @@ -1010,6 +1067,7 @@ struct DiffTransposePass case kIROp_conditionalBranch: case kIROp_ifElse: case kIROp_loop: + case kIROp_Switch: { // Ignore. transposeBlock() should take care of adding the // appropriate branch instruction. diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index b06ed29bf..e616578c1 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -425,6 +425,40 @@ struct DiffUnzipPass as<IRBlock>(diffMap[afterBlock]))); } + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(branchInst); + auto breakBlock = switchInst->getBreakLabel(); + auto defaultBlock = switchInst->getDefaultLabel(); + auto condInst = switchInst->getCondition(); + + List<IRInst*> primalCaseArgs; + List<IRInst*> diffCaseArgs; + + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + primalCaseArgs.add(switchInst->getCaseValue(ii)); + diffCaseArgs.add(switchInst->getCaseValue(ii)); + + primalCaseArgs.add(primalMap[switchInst->getCaseLabel(ii)]); + diffCaseArgs.add(diffMap[switchInst->getCaseLabel(ii)]); + } + + return InstPair( + primalBuilder->emitSwitch( + condInst, + as<IRBlock>(primalMap[breakBlock]), + as<IRBlock>(primalMap[defaultBlock]), + primalCaseArgs.getCount(), + primalCaseArgs.getBuffer()), + diffBuilder->emitSwitch( + condInst, + as<IRBlock>(diffMap[breakBlock]), + as<IRBlock>(diffMap[defaultBlock]), + diffCaseArgs.getCount(), + diffCaseArgs.getBuffer())); + } + default: SLANG_UNEXPECTED("Unhandled instruction"); } @@ -452,6 +486,8 @@ struct DiffUnzipPass case kIROp_unconditionalBranch: case kIROp_conditionalBranch: case kIROp_ifElse: + case kIROp_Switch: + case kIROp_loop: return splitControlFlow(primalBuilder, diffBuilder, inst); case kIROp_Unreachable: |
