diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 64 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 36 | ||||
| -rw-r--r-- | tests/autodiff/reverse-switch-case.slang | 55 | ||||
| -rw-r--r-- | tests/autodiff/reverse-switch-case.slang.expected.txt | 6 |
7 files changed, 215 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fa3eb463e..e8fe2beac 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -784,6 +784,47 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig return InstPair(diffLoop, diffLoop); } +InstPair ForwardDiffTranscriber::transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch) +{ + // Transcribe condition (primal only, conditions do not produce differentials) + auto primalCondition = findOrTranscribePrimalInst(builder, origSwitch->getCondition()); + SLANG_ASSERT(primalCondition); + + // Transcribe 'default' block + IRBlock* diffDefaultBlock = as<IRBlock>( + findOrTranscribeDiffInst(builder, origSwitch->getDefaultLabel())); + SLANG_ASSERT(diffDefaultBlock); + + // Transcribe 'default' block + IRBlock* diffBreakBlock = as<IRBlock>( + findOrTranscribeDiffInst(builder, origSwitch->getBreakLabel())); + SLANG_ASSERT(diffBreakBlock); + + // Transcribe all other operands + List<IRInst*> diffCaseValuesAndLabels; + for (UIndex ii = 0; ii < origSwitch->getCaseCount(); ii ++) + { + auto primalCaseValue = findOrTranscribePrimalInst(builder, origSwitch->getCaseValue(ii)); + SLANG_ASSERT(primalCaseValue); + + auto diffCaseBlock = findOrTranscribeDiffInst(builder, origSwitch->getCaseLabel(ii)); + SLANG_ASSERT(diffCaseBlock); + + diffCaseValuesAndLabels.add(primalCaseValue); + diffCaseValuesAndLabels.add(diffCaseBlock); + } + + auto diffSwitchInst = builder->emitSwitch( + primalCondition, + diffBreakBlock, + diffDefaultBlock, + diffCaseValuesAndLabels.getCount(), + diffCaseValuesAndLabels.getBuffer()); + builder->markInstAsMixedDifferential(diffSwitchInst); + + return InstPair(diffSwitchInst, diffSwitchInst); +} + InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) { // IfElse Statements come with 4 blocks. We transcribe each block into it's @@ -1123,6 +1164,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); + + case kIROp_Switch: + return transcribeSwitch(builder, as<IRSwitch>(origInst)); case kIROp_MakeDifferentialPair: return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst)); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 828916c01..b09b57974 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -65,6 +65,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse); + InstPair transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch); + InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst); InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 8d6419cf2..de4fbe182 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -666,14 +666,15 @@ namespace Slang builder->setInsertInto(fwdDiffParameterBlock); - // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> - for (auto child = fwdDiffParameterBlock->getFirstParam(); child;) + List<IRParam*> fwdParams; + for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam()) { - IRParam* nextChild = child->getNextParam(); + fwdParams.add(child); + } - auto fwdParam = as<IRParam>(child); - SLANG_ASSERT(fwdParam); - + // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> + for (auto fwdParam : fwdParams) + { // TODO: Handle ptr<pair> types. if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType())) { @@ -690,10 +691,11 @@ namespace Slang else { // Default case (parameter has nothing to do with differentiation) - // Do nothing. + // Simply move the parameter to the end. + // + fwdParam->removeFromParent(); + fwdDiffParameterBlock->addParam(fwdParam); } - - child = nextChild; } auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index cbdb0a998..ae1a5dd70 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -858,7 +858,8 @@ struct DiffTransposePass // Check for predecessors count. for (auto predecessor : fwdBlockInst->getPredecessors()) { - fwdPredecesorBlocks.add(predecessor); + if (!fwdPredecesorBlocks.contains(predecessor)) + fwdPredecesorBlocks.add(predecessor); } SLANG_ASSERT(fwdPredecesorBlocks.getCount() > 0); @@ -915,7 +916,7 @@ struct DiffTransposePass { IRBlock* revBlock = revBlockMap[fwdBlockInst]; - // If we already have a terminator, we've resolved it during + // If we already have a terminator, we've probably resolved it during // tryEmitTerminator() // if (revBlock->getTerminator() != nullptr) @@ -960,7 +961,63 @@ struct DiffTransposePass } builder->emitIfElse(condition, revTrueBlock, revFalseBlock, revAfterBlock); - break; + return true; + } + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(terminatorInst); + + auto condition = switchInst->getCondition(); + SLANG_ASSERT(!isDifferentialInst(condition)); + + // fwd origin block is the reverse 'break' block. + auto revAfterBlock = as<IRBlock>( + revBlockMap[as<IRBlock>(switchInst->getParent())]); + + // Find regions for every branch, and find the reverse-mode + // version of the each exit block. + Region* defaultRegion = regionMap[switchInst->getDefaultLabel()]; + IRBlock* revDefaultBlock = revBlockMap[defaultRegion->exitBlock]; + + List<IRBlock*> revCaseBlocks; + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + Region* caseRegion = regionMap[switchInst->getCaseLabel(ii)]; + IRBlock* revCaseBlock = revBlockMap[caseRegion->exitBlock]; + revCaseBlocks.add(revCaseBlock); + } + + // If we have phi derivatives to pass on, + // we need to add dummy blocks to pass them using + // an unconditional branch. + // + if (phiParamGrads.getCount() > 0) + { + revDefaultBlock = insertPhiBlockBefore(revDefaultBlock, phiParamGrads); + revDefaultBlock->insertAfter(revBlock); + + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + revCaseBlocks[ii] = insertPhiBlockBefore(revCaseBlocks[ii], phiParamGrads); + revCaseBlocks[ii]->insertAfter(revBlock); + } + } + + List<IRInst*> revCaseArgs; + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + revCaseArgs.add(switchInst->getCaseValue(ii)); + revCaseArgs.add(revCaseBlocks[ii]); + } + + builder->emitSwitch( + condition, + revAfterBlock, + revDefaultBlock, + revCaseArgs.getCount(), + revCaseArgs.getBuffer()); + + return true; } default: SLANG_UNIMPLEMENTED_X("Unhandled control flow inst during transposition"); @@ -1010,6 +1067,7 @@ struct DiffTransposePass case kIROp_conditionalBranch: case kIROp_ifElse: case kIROp_loop: + case kIROp_Switch: { // Ignore. transposeBlock() should take care of adding the // appropriate branch instruction. diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index b06ed29bf..e616578c1 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -425,6 +425,40 @@ struct DiffUnzipPass as<IRBlock>(diffMap[afterBlock]))); } + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(branchInst); + auto breakBlock = switchInst->getBreakLabel(); + auto defaultBlock = switchInst->getDefaultLabel(); + auto condInst = switchInst->getCondition(); + + List<IRInst*> primalCaseArgs; + List<IRInst*> diffCaseArgs; + + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) + { + primalCaseArgs.add(switchInst->getCaseValue(ii)); + diffCaseArgs.add(switchInst->getCaseValue(ii)); + + primalCaseArgs.add(primalMap[switchInst->getCaseLabel(ii)]); + diffCaseArgs.add(diffMap[switchInst->getCaseLabel(ii)]); + } + + return InstPair( + primalBuilder->emitSwitch( + condInst, + as<IRBlock>(primalMap[breakBlock]), + as<IRBlock>(primalMap[defaultBlock]), + primalCaseArgs.getCount(), + primalCaseArgs.getBuffer()), + diffBuilder->emitSwitch( + condInst, + as<IRBlock>(diffMap[breakBlock]), + as<IRBlock>(diffMap[defaultBlock]), + diffCaseArgs.getCount(), + diffCaseArgs.getBuffer())); + } + default: SLANG_UNEXPECTED("Unhandled instruction"); } @@ -452,6 +486,8 @@ struct DiffUnzipPass case kIROp_unconditionalBranch: case kIROp_conditionalBranch: case kIROp_ifElse: + case kIROp_Switch: + case kIROp_loop: return splitControlFlow(primalBuilder, diffBuilder, inst); case kIROp_Unreachable: diff --git a/tests/autodiff/reverse-switch-case.slang b/tests/autodiff/reverse-switch-case.slang new file mode 100644 index 000000000..21a7565af --- /dev/null +++ b/tests/autodiff/reverse-switch-case.slang @@ -0,0 +1,55 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_simple_switch(float y, int i) +{ + float o; + switch (i) + { + case 0: + case 1: + o = y * 2.0; + break; + + case 2: + o = y * 3.0; + break; + + default: + o = y; + } + + return o; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_switch)(dpa, 1, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 2.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_simple_switch)(dpa, 2, 1.0f); + outputBuffer[1] = dpa.d; // Expect: 3.0 + } + + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_switch)(dpa, 3, 1.0f); + outputBuffer[2] = dpa.d; // Expect: 1.0 + } +} diff --git a/tests/autodiff/reverse-switch-case.slang.expected.txt b/tests/autodiff/reverse-switch-case.slang.expected.txt new file mode 100644 index 000000000..6e3b3dd12 --- /dev/null +++ b/tests/autodiff/reverse-switch-case.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +2.000000 +3.000000 +1.000000 +0.000000 +0.000000 |
